from threading import Thread
import os
import time
import logging
import threading
import socket
import sys
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd
import tushare as ts
from log_style import setup_logger, get_logger
from functools import lru_cache
import argparse

# Configure the logger
logger = setup_logger(
    name='update_tushare_totxt',
    log_file='app.log',
    level=logging.INFO,
    log_format='%(asctime)s - %(levelname)s - %(message)s',
    console=True
)


# Colored logging helpers
def log_warning(message, **kwargs):
    """Warning message, shown in yellow."""
    logger.warning(f"\033[33m{message}\033[0m", **kwargs)

def log_failure(message, **kwargs):
    """Failure message, shown in red."""
    logger.error(f"\033[31m{message}\033[0m", **kwargs)

def log_success(message, **kwargs):
    """Success message, shown in green."""
    logger.info(f"\033[32m{message}\033[0m", **kwargs)

def log_process(message, **kwargs):
    """In-progress message, shown in yellow."""
    logger.info(f"\033[33m{message}\033[0m", **kwargs)

def log_result(message, **kwargs):
    """Result message, shown in blue."""
    logger.info(f"\033[34m{message}\033[0m", **kwargs)

def log_message(message, **kwargs):
    """Plain message, shown in white."""
    logger.info(f"\033[37m{message}\033[0m", **kwargs)

# Database-related imports
import pymysql
from sqlalchemy import create_engine, text

class Config:
    """Global configuration."""
    BASE_DIR = 'D:\\gp_data'
    INPUT_FILE = os.path.join(BASE_DIR, 'code', 'all_stock_codes.txt')
    OUTPUT_DIR = os.path.join(BASE_DIR, 'day')
    INDEX_DIR = os.path.join(BASE_DIR, 'index')
    # Maximum number of worker threads; override with the MAX_THREADS environment variable (default 20)
    MAX_THREADS = int(os.getenv('MAX_THREADS', 20))
    # Maximum requests per minute per account; override with REQUEST_LIMIT (default 500)
    REQUEST_LIMIT = int(os.getenv('REQUEST_LIMIT', 500))
    # Maximum number of retries when a request fails
    MAX_RETRIES = 3
    # Time constant of 60 seconds (1 minute), used for request-rate control
    ONE_MINUTE = 60
    # Tushare tokens for multiple accounts, read from environment variables with fallbacks
    ACCOUNTS = [
        os.getenv('TUSHARE_TOKEN1', '9343e641869058684afeadfcfe7fd6684160852e52e85332a7734c8d'),  # primary account
        os.getenv('TUSHARE_TOKEN2', 'b156282fb2328afa056346736a93778861ac6be97a485f4d1f55c493')   # backup account 1
    ]
    # Database retry configuration
    DB_MAX_RETRIES = 3      # maximum number of retries
    DB_RETRY_INTERVAL = 3   # retry interval in seconds
    # Database configuration; sensitive values come from environment variables
    DB_CONFIG = {
        'host': os.getenv('DB_HOST', '127.0.0.1'),
        'port': int(os.getenv('DB_PORT', 3306)),
        'user': os.getenv('DB_USER', 'root'),
        'password': os.getenv('DB_PASSWORD', 'stella0850'),
        'database': os.getenv('DB_NAME', 'stock_data')
    }

    # Database connection pool configuration
    DB_POOL_SIZE = 20      # connection pool size
    DB_MAX_OVERFLOW = 10   # maximum overflow connections

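# A minimal usage sketch for overriding the limits above via environment variables
# (the shell syntax and the script filename are assumptions, not taken from the code):
#   $env:MAX_THREADS = "10"        # Windows PowerShell
#   $env:REQUEST_LIMIT = "300"
#   python update_tushare_totxt.py -c 3
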
class AccountManager:
    # Class-level cache shared across all instances
    _stock_basic_cache = None
    _stock_basic_cache_time = 0

    def __init__(self):
        """Initialize the account manager."""
        self.accounts = [ts.pro_api(token) for token in Config.ACCOUNTS]
        self.current_index = 0
        self.lock = threading.Lock()
        # Track the last request time per account to improve concurrency
        self.last_request_time_per_account = [time.time() for _ in self.accounts]

    def get_stock_basic(self, force_primary=False):
        """Fetch stock basic information, with caching."""
        current_time = time.time()

        # Use the class-level cache so it is shared across instances.
        # The cache is valid for 30 minutes.
        if (AccountManager._stock_basic_cache is not None and
                current_time - AccountManager._stock_basic_cache_time < 30 * 60):
            return AccountManager._stock_basic_cache

        # Cache expired; fetch again
        pro = self.get_next_account(force_primary)
        df_sse = pro.stock_basic(exchange='SSE', list_status='L')
        df_szse = pro.stock_basic(exchange='SZSE', list_status='L')

        AccountManager._stock_basic_cache = pd.concat([df_sse, df_szse], ignore_index=True)
        AccountManager._stock_basic_cache_time = current_time

        return AccountManager._stock_basic_cache

    def get_next_account(self, force_primary=False):
        """Return the next account in round-robin order.

        force_primary: always return the primary account.
        """
        with self.lock:
            try:
                if force_primary:
                    return self.accounts[0]  # force the primary account
                account = self.accounts[self.current_index]
                self.current_index = (self.current_index + 1) % len(self.accounts)
                return account
            except (ConnectionError, TimeoutError) as e:
                # Network-related errors
                log_failure(f"Account network error: {str(e)}")
            except Exception as e:
                # Any other error
                log_failure(f"Failed to get an account: {str(e)}")

            # On any error, move on to the next account
            self.current_index = (self.current_index + 1) % len(self.accounts)

        # Retry outside the lock (threading.Lock is not re-entrant) and cap the
        # number of attempts to avoid unbounded recursion.
        retry_count = getattr(self, '_retry_count', 0)
        if retry_count >= len(self.accounts) * 2:  # try each account twice
            self._retry_count = 0
            log_failure("All accounts failed; falling back to the primary account")
            return self.accounts[0]

        self._retry_count = retry_count + 1
        return self.get_next_account(force_primary)


class DataDownloader:
    """Core data downloader."""

    def __init__(self):
        self.account_manager = AccountManager()
        self.setup_logging()
        self.create_dirs()
        self.last_request_time = time.time()
        self.db_conn = None  # database connection attribute
        self.total_files = 0
        self.processed_files = 0
        self.progress = {'current': 0, 'total': 100, 'message': 'Preparing to start'}
        # Database connection pool management
        self.conn_pool = {}
        self.conn_count = 0
        self.max_connections = 5  # maximum number of connections
        self.active_connections = set()

    def show_progress(self, current, total, start_time):
        """Update and display progress."""
        self.progress = {
            'current': current,
            'total': total,
            'message': f'Processing: {current / total * 100:.1f}% | elapsed: {time.time() - start_time:.1f}s'
        }
        progress = current / total * 100
        elapsed = time.time() - start_time
        # Note: the logger does not accept print-style end/flush keywords, so each
        # progress update is emitted as a normal log record.
        log_process(f"Processing: {progress:.1f}% ({current}/{total}) | elapsed: {elapsed:.1f}s")
        if current == total:
            log_success("Data processing finished!")

    # def show_progress(self, current, total, start_time):
    #     """More frequent progress display."""
    #     if current == total or current % 2 == 0:
    #         progress = current / total * 100
    #         elapsed = time.time() - start_time
    #         print(f"\rProcessing: {progress:.1f}% ({current}/{total}) | elapsed: {elapsed:.1f}s", end='')
    #         if current == total:
    #             print(" done!")

    def setup_logging(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler('app.log')
            ]
        )

    def connect_db(self):
        """Connect to the database, with retries and automatic database creation."""
        # Reuse a pooled connection if one is available
        if self.conn_pool:
            # Take the first available connection
            conn_id, conn = next(iter(self.conn_pool.items()))
            try:
                # Check that the connection is still alive
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
                cursor.close()
                del self.conn_pool[conn_id]
                self.active_connections.add(conn_id)
                log_message(f"Reusing database connection {conn_id}")
                self.db_conn = conn
                return True
            except Exception:
                # Connection is stale: drop it and fall through to create a new one
                try:
                    conn.close()
                except Exception:
                    pass
                del self.conn_pool[conn_id]
                log_warning(f"Removed stale connection {conn_id}")

        # If the connection count has reached the limit, wait for a slot
        while self.conn_count >= self.max_connections:
            log_warning(f"Database connection limit {self.max_connections} reached, waiting...")
            time.sleep(1)

        for attempt in range(Config.DB_MAX_RETRIES):
            try:
                # Try to connect to the configured database first
                conn = pymysql.connect(**Config.DB_CONFIG)
                conn_id = f"conn_{int(time.time() * 1000)}"
                self.conn_count += 1
                self.active_connections.add(conn_id)
                log_message(f"Created database connection {conn_id}, current connection count: {self.conn_count}")
                self.db_conn = conn
                return True
            except pymysql.err.OperationalError as e:
                if e.args[0] == 1049:  # "unknown database" error code
                    try:
                        # Connect to the MySQL server without selecting a database
                        temp_config = Config.DB_CONFIG.copy()
                        temp_config.pop('database')
                        temp_conn = pymysql.connect(**temp_config)

                        # Create the database named in Config.DB_CONFIG
                        with temp_conn.cursor() as cursor:
                            cursor.execute(f"CREATE DATABASE IF NOT EXISTS {Config.DB_CONFIG['database']} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
                        temp_conn.commit()
                        temp_conn.close()

                        log_message(f"Database {Config.DB_CONFIG['database']} created successfully")
                        # Retry the connection
                        continue
                    except Exception as create_e:
                        log_failure(f"Failed to create database: {create_e}")
                        return False
                elif attempt < Config.DB_MAX_RETRIES - 1:  # not the last attempt
                    log_warning(f"Database connection failed (attempt {attempt + 1}/{Config.DB_MAX_RETRIES}): {e}")
                    time.sleep(Config.DB_RETRY_INTERVAL)
                else:
                    log_failure(f"Database connection failed after {Config.DB_MAX_RETRIES} attempts")
                    return False
            except Exception as e:
                if attempt < Config.DB_MAX_RETRIES - 1:  # not the last attempt
                    log_warning(f"Database connection failed (attempt {attempt + 1}/{Config.DB_MAX_RETRIES}): {e}")
                    time.sleep(Config.DB_RETRY_INTERVAL)
                else:
                    log_failure(f"Database connection failed after {Config.DB_MAX_RETRIES} attempts")
                    return False
        return False

    def return_connection(self, conn=None):
        """Return a connection to the pool, or close it if it is no longer valid."""
        # If a connection is passed in, handle that one
        if conn:
            try:
                # Check that the connection is still alive
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
                cursor.close()
                # Generate a unique connection id
                conn_id = f"conn_{int(time.time() * 1000)}"
                # Remove it from the active set and put it back into the pool
                if hasattr(self, 'active_connections'):
                    for active_conn_id in list(self.active_connections):
                        if active_conn_id.endswith(str(hash(str(conn)))):
                            self.active_connections.remove(active_conn_id)
                # Add to the pool
                self.conn_pool[conn_id] = conn
                log_message(f"Connection returned to the pool: {conn_id}")
                return True
            except Exception:
                # Connection is no longer valid: close it
                try:
                    conn.close()
                    if hasattr(self, 'conn_count'):
                        self.conn_count = max(0, self.conn_count - 1)
                        log_message(f"Stale connection closed, current connection count: {self.conn_count}")
                except Exception:
                    pass
                return False
        # Otherwise handle the object's default connection
        elif hasattr(self, 'db_conn') and self.db_conn:
            try:
                # Check that the connection is still alive
                cursor = self.db_conn.cursor()
                cursor.execute("SELECT 1")
                cursor.close()
                # Generate a unique connection id
                conn_id = f"conn_{int(time.time() * 1000)}"
                # Remove it from the active set and put it back into the pool
                for active_conn_id in list(self.active_connections):
                    self.active_connections.remove(active_conn_id)
                # Add to the pool
                self.conn_pool[conn_id] = self.db_conn
                self.db_conn = None  # clear the instance attribute
                log_message(f"Default connection returned to the pool: {conn_id}")
                return True
            except Exception:
                # Connection is no longer valid: close it
                try:
                    self.db_conn.close()
                    self.db_conn = None
                    self.conn_count = max(0, self.conn_count - 1)
                    log_message(f"Stale default connection closed, current connection count: {self.conn_count}")
                except Exception:
                    pass
                return False
        return False

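    # A minimal usage sketch (an assumption, not taken from the original code): a caller
    # that needs a raw pymysql connection could pair the two methods above like this:
    #   if downloader.connect_db():
    #       cur = downloader.db_conn.cursor()
    #       cur.execute("SELECT 1")
    #       cur.close()
    #       downloader.return_connection()  # hand the connection back to the pool
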
    def update_to_db(self, data: pd.DataFrame, table_name: str):
        """Thread-safe database update."""
        engine = None
        try:
            # Each thread creates its own engine and connection, with a tuned connection pool
            engine = create_engine(
                f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@"
                f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4",
                pool_size=Config.MAX_THREADS,   # connection pool size
                max_overflow=Config.DB_MAX_OVERFLOW,
                pool_timeout=30,                # connection timeout
                pool_recycle=3600,              # recycle connections after an hour
                pool_pre_ping=True,             # validate connections before use
                echo=False,
                connect_args={
                    "charset": "utf8mb4",
                    "autocommit": True,         # autocommit
                    "connect_timeout": 30,
                    "read_timeout": 30,
                    "write_timeout": 30
                }
            )

            # Force table names to lower case and strip the '.' from ts_code-style names
            table_name = table_name.lower().replace('.', '_')

            # Basic schema check for per-stock tables
            if table_name.startswith('stock_'):
                if 'trade_date' not in data.columns:
                    log_failure("Data is missing the required column trade_date")
                    return False
                if not pd.api.types.is_datetime64_any_dtype(data['trade_date']):
                    data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d')

            with engine.begin() as connection:  # run inside a transaction
                data.to_sql(table_name, connection, if_exists='append', index=False, chunksize=1000)
            return True
        except Exception as e:
            log_failure(f"Database update failed: {e}")
            return False
        finally:
            # Make sure resources are released
            if engine:
                try:
                    engine.dispose()
                    log_message("SQLAlchemy engine disposed")
                except Exception as e:
                    log_failure(f"Failed to dispose SQLAlchemy engine: {e}")

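    # Worked example of the table-name normalisation above (illustration only): passing
    # table_name='stock_000001.SZ' yields the MySQL table 'stock_000001_sz', because the
    # name is lower-cased and the '.' in the ts_code is replaced with '_'.
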
    def update_database(self, progress_queue=None, completion_label=None):
        try:
            log_process("Starting database update...")
            start_time = time.time()

            # 1. Update the stock code table (using the cached basic info)
            log_process("Updating the stock code table...")
            df_all = self.account_manager.get_stock_basic(force_primary=True)
            self.update_to_db(df_all, 'stock_basic')

            # 2. Update daily data from local files
            log_process("Updating per-stock daily data...")
            files = [f for f in os.listdir(Config.OUTPUT_DIR) if f.endswith('_daily_data.txt')]
            self.total_files = len(files)
            self.processed_files = 0

            if self.total_files > 0:
                # Use batched inserts to speed up database writes
                self.batch_update_stock_data(files, start_time)
            else:
                log_message("No per-stock data files found to process")

            # Older per-file loop kept for reference:
            # self.show_progress(0, self.total_files, start_time)
            # for i, file in enumerate(files, 1):
            #     file_path = os.path.join(Config.OUTPUT_DIR, file)
            #     df = self.read_from_txt(file_path)
            #     if df is not None:
            #         code = file.split('_')[0].lower().replace('.', '_')
            #         self.update_to_db(df, f'stock_{code}')
            #     self.processed_files = i
            #     self.show_progress(i, self.total_files, start_time)

            # 3. Update index data (from the local file)
            log_process("Updating index data...")
            index_file = os.path.join(Config.INDEX_DIR, '000001.SH.txt')
            if os.path.exists(index_file):
                df_index = self.read_from_txt(index_file)
                if df_index is not None:
                    self.update_to_db(df_index, 'index_daily')

            # 4. Update per-stock data from the API (in batches)
            log_process("Updating real-time per-stock data...")
            pro = self.account_manager.get_next_account()  # API client for the batched daily() calls below
            stock_codes = df_all['ts_code'].tolist()
            batch_size = 400
            for i in range(0, len(stock_codes), batch_size):
                batch = stock_codes[i:i + batch_size]
                df = self.fetch_data_with_retry(pro.daily, ts_code=','.join(batch))
                if df is not None:
                    self.update_to_db(df, 'stock_daily')

            log_success(f"Database update finished! Total time: {time.time() - start_time:.2f}s")
            if completion_label:
                completion_label.config(text="Database update finished!", foreground="green")
            if progress_queue:
                progress_queue.put(1)
            return True
        except Exception as e:
            log_failure(f"Database update failed: {e}")
            if completion_label:
                completion_label.config(text="Database update failed!", foreground="red")
            return False

    def batch_update_stock_data(self, files, start_time):
        """Batch-update stock data into the database for better performance."""
        engine = None
        try:
            # Create one shared engine with a tuned connection pool
            engine = create_engine(
                f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@"
                f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4",
                pool_size=Config.MAX_THREADS,
                max_overflow=Config.DB_MAX_OVERFLOW,
                pool_timeout=30,
                pool_recycle=3600,
                pool_pre_ping=True,
                echo=False  # disable SQL logging for performance
            )

            # Process files in batches of 100
            batch_size = 100
            processed_count = 0
            total_files = len(files)

            for i in range(0, len(files), batch_size):
                batch_files = files[i:i + batch_size]
                all_data = []

                # Collect one batch of data
                for file in batch_files:
                    file_path = os.path.join(Config.OUTPUT_DIR, file)
                    df = self.read_from_txt(file_path)
                    if df is not None:
                        code = file.split('_')[0].lower().replace('.', '_')
                        df['stock_code'] = code  # tag rows with the stock code
                        all_data.append(df)
                    # Update per-file progress
                    processed_count += 1
                    self.show_progress(processed_count, total_files, start_time)

                # Concatenate and bulk-insert the batch
                if all_data:
                    combined_df = pd.concat(all_data, ignore_index=True)
                    # Insert inside a transaction
                    with engine.begin() as connection:
                        # De-duplication, if needed, would go here before the insert
                        combined_df.to_sql('stock_data_combined', connection, if_exists='append', index=False,
                                           chunksize=2000, method='multi')

            # Make sure the final progress reads 100%
            if total_files > 0:
                self.show_progress(total_files, total_files, start_time)

        except Exception as e:
            log_failure(f"Batch update of stock data failed: {e}")
        finally:
            # Make sure resources are released
            if engine:
                try:
                    engine.dispose()
                    log_message("SQLAlchemy engine for the batch update disposed")
                except Exception as e:
                    log_failure(f"Failed to dispose the batch update SQLAlchemy engine: {e}")

    def create_dirs(self):
        # Base directory
        os.makedirs(Config.BASE_DIR, exist_ok=True)
        # Directory for the code list file
        os.makedirs(os.path.dirname(Config.INPUT_FILE), exist_ok=True)
        # Output directory
        os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
        # Index directory
        os.makedirs(Config.INDEX_DIR, exist_ok=True)

    # fetch_data_with_retry enforces the request-rate limit across accounts
    def fetch_data_with_retry(self, func, *args, **kwargs):
        """Fetch data with retries and request-rate limiting."""
        # Compute the allowed request rate (total requests per minute across all accounts)
        total_requests_per_minute = Config.REQUEST_LIMIT * len(Config.ACCOUNTS)
        requests_per_second = total_requests_per_minute / 60.0
        min_interval = 1.0 / requests_per_second

        for attempt in range(Config.MAX_RETRIES):
            try:
                # Enforce the minimum interval between requests
                current_time = time.time()
                time_since_last_request = current_time - self.last_request_time

                if time_since_last_request < min_interval:
                    time.sleep(min_interval - time_since_last_request)

                result = func(*args, **kwargs)
                self.last_request_time = time.time()
                return result if result is not None and not (hasattr(result, 'empty') and result.empty) else None
            except Exception as e:
                # Exponential backoff, capped at 5 seconds
                wait_time = min(2 ** attempt, 5)
                time.sleep(wait_time)
        # All retries failed
        return None

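    # Worked example of the rate limit above, using the current defaults (an assumption if
    # REQUEST_LIMIT or the account list is overridden): 500 requests/minute * 2 accounts
    # = 1000 requests/minute, about 16.7 requests/second, so min_interval is roughly 0.06 s.
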
    def save_to_txt(self, data: pd.DataFrame, filename: str) -> bool:
        try:
            # Sort by trade date descending so the most recent trading day comes first
            if 'trade_date' in data.columns:
                data = data.sort_values('trade_date', ascending=False)
            data.to_csv(filename, index=False, sep='\t', encoding='utf-8')
            # logging.info(f"Data saved to {filename}")
            return True
        except Exception as e:
            log_failure(f"Failed to save file: {e}")
            return False

    def process_stock_code(self, code, progress_queue=None):
        pro = self.account_manager.get_next_account()
        try:
            output_file = os.path.join(Config.OUTPUT_DIR, f"{code}_daily_data.txt")

            # Check whether a data file already exists
            if os.path.exists(output_file):
                # Read the existing data and find the latest trade date
                existing_df = self.read_from_txt(output_file)
                if existing_df is not None and not existing_df.empty:
                    if 'trade_date' in existing_df.columns:
                        # read_from_txt converts trade_date to datetime; make sure it really is
                        if not pd.api.types.is_datetime64_any_dtype(existing_df['trade_date']):
                            existing_df['trade_date'] = pd.to_datetime(existing_df['trade_date'], format='%Y%m%d')

                        # Latest trade date already on disk
                        latest_date_dt = existing_df['trade_date'].max()
                        # Start from the next day to avoid re-fetching the same day
                        next_date_dt = latest_date_dt + timedelta(days=1)
                        next_date = next_date_dt.strftime('%Y%m%d')

                        # Fetch only the data after the latest date
                        df = self.fetch_data_with_retry(pro.daily, ts_code=code, start_date=next_date)

                        if df is not None and not df.empty:
                            # Convert the new data's trade_date to datetime so it can be merged
                            df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')

                            # Merge existing and new data
                            combined_df = pd.concat([existing_df, df], ignore_index=True)

                            # Drop duplicate rows
                            combined_df = combined_df.drop_duplicates(subset=['trade_date', 'ts_code'], keep='last')

                            # Sort by trade date descending, most recent first
                            combined_df = combined_df.sort_values('trade_date', ascending=False)

                            # Convert trade_date back to a string before saving
                            combined_df['trade_date'] = combined_df['trade_date'].dt.strftime('%Y%m%d')

                            # Save the merged data
                            self.save_to_txt(combined_df, output_file)
                    else:
                        # Existing data has no trade_date column: fetch everything again
                        df = self.fetch_data_with_retry(pro.daily, ts_code=code)
                        if df is not None:
                            self.save_to_txt(df, output_file)
                else:
                    # Existing file is empty: fetch everything again
                    df = self.fetch_data_with_retry(pro.daily, ts_code=code)
                    if df is not None:
                        self.save_to_txt(df, output_file)
            else:
                # No file yet: fetch all data
                df = self.fetch_data_with_retry(pro.daily, ts_code=code)
                if df is not None:
                    self.save_to_txt(df, output_file)

        except (ConnectionError, TimeoutError) as e:
            log_failure(f"Stock {code}: network error: {str(e)}")
        except pd.errors.EmptyDataError:
            log_warning(f"Stock {code}: empty data returned")
        except (pymysql.err.OperationalError, pymysql.err.InterfaceError) as e:
            log_failure(f"Stock {code}: database error: {str(e)}")
        except Exception as e:
            log_failure(f"Error while processing stock code {code}: {str(e)}")
        finally:
            # Report progress exactly once per code; the finally block covers success and failure
            if progress_queue is not None:
                progress_queue.put(1)

    def fetch_and_save_index_data(self, ts_code='000001.SH', progress_queue=None, completion_label=None):
        pro = self.account_manager.get_next_account(force_primary=True)  # force the primary account
        try:
            df = pro.index_daily(ts_code=ts_code)
            if df is None or df.empty:
                log_warning(f"No daily data returned for index {ts_code}")
                return False

            os.makedirs(Config.INDEX_DIR, exist_ok=True)
            output_file = os.path.join(Config.INDEX_DIR, f"{ts_code}.txt")
            self.save_to_txt(df, output_file)
            if completion_label:
                completion_label.config(text="Index data download finished!", foreground="green")
            if progress_queue:
                progress_queue.put(1)
            return True
        except Exception as e:
            log_failure(f"Error fetching index data for {ts_code}: {e}")
            if completion_label:
                completion_label.config(text="Failed to fetch index data!", foreground="red")
            return False

    def get_all_stock_codes(self, progress_queue=None, completion_label=None):
        try:
            # Use the cached stock basic info
            df_all = self.account_manager.get_stock_basic(force_primary=True)
            df_all[['ts_code', 'name']].to_csv(Config.INPUT_FILE, index=False, header=False, sep='\t')

            log_success(f"All stock codes saved to {Config.INPUT_FILE}")
            if completion_label:
                completion_label.config(text="Code table update finished!", foreground="green")
            if progress_queue:
                progress_queue.put(1)
        except Exception as e:
            logging.error(f"Error: {e}")
            if completion_label:
                completion_label.config(text="Failed to update the code table!", foreground="red")

    def fetch_and_save_stock_basic_info(self):
        """Fetch per-stock basic information and save it to the baseinfo directory."""
        try:
            log_process("Fetching per-stock basic information...")

            # Get a Tushare API instance
            pro = self.account_manager.get_next_account(force_primary=True)  # force the primary account

            # Fetch the basic information
            df = pro.stock_basic(
                exchange='',      # exchange code; empty means all exchanges
                list_status='L',  # listing status: L listed, D delisted, P suspended
                fields='ts_code,symbol,name,area,industry,fullname,enname,market,exchange,curr_type,list_status,list_date,delist_date,is_hs'
            )

            if df is None or df.empty:
                log_failure("No per-stock basic information returned")
                return False

            # Create the baseinfo directory
            baseinfo_dir = os.path.join(Config.BASE_DIR, 'baseinfo')
            os.makedirs(baseinfo_dir, exist_ok=True)

            # Save the data as TXT rather than CSV
            output_file = os.path.join(baseinfo_dir, 'stock_basic_info.txt')
            self.save_to_txt(df, output_file)

            log_success(f"Per-stock basic information saved to: {output_file}")
            log_result(f"Fetched {len(df)} rows of per-stock basic information")
            return True
        except Exception as e:
            log_failure(f"Failed to fetch per-stock basic information: {str(e)}")
            return False

    def process_stock_codes_batch(self, codes: list):
        """Process stock codes in batches (up to 400 codes per request)."""
        pro = self.account_manager.get_next_account()
        processed_count = 0
        failed_count = 0

        try:
            # Process in batches of 400, given the per-request row limit
            batch_size = 400
            total_batches = (len(codes) + batch_size - 1) // batch_size

            for batch_idx in range(0, len(codes), batch_size):
                batch = codes[batch_idx:batch_idx + batch_size]

                try:
                    # Fetch the whole batch in one request
                    df = self.fetch_data_with_retry(pro.daily, ts_code=','.join(batch))

                    if df is not None and not df.empty:
                        # Split the result by stock code and save each part
                        for code in batch:
                            try:
                                # Rows belonging to the current code
                                code_df = df[df['ts_code'] == code]

                                if not code_df.empty:
                                    # Save to file
                                    output_file = os.path.join(Config.OUTPUT_DIR, f"{code}_daily_data.txt")
                                    if self.save_to_txt(code_df, output_file):
                                        # Update the database
                                        if self.update_to_db(code_df, 'stock_daily'):
                                            processed_count += 1
                                        else:
                                            failed_count += 1
                                            log_warning(f"Database update failed, but the file was saved: {code}")
                                    else:
                                        failed_count += 1
                                        log_failure(f"Failed to save file: {code}")
                                else:
                                    log_warning(f"No data returned for stock {code}")
                                    failed_count += 1
                            except Exception as code_e:
                                failed_count += 1
                                log_failure(f"Error while processing stock {code}: {str(code_e)}")
                    else:
                        failed_count += len(batch)
                        log_warning(f"Batch {batch_idx//batch_size + 1}/{total_batches} returned no data")

                    # Per-batch progress
                    log_result(f"Batch {batch_idx//batch_size + 1}/{total_batches} done | succeeded: {processed_count} | failed: {failed_count}")

                except (ConnectionError, TimeoutError) as e:
                    failed_count += len(batch)
                    log_failure(f"Batch {batch_idx//batch_size + 1}: network error: {str(e)}")
                    # Switch account before the next batch
                    pro = self.account_manager.get_next_account()
                except Exception as batch_e:
                    failed_count += len(batch)
                    log_failure(f"Batch {batch_idx//batch_size + 1}: error: {str(batch_e)}")

            # Overall result
            log_result(f"Batch processing finished | total succeeded: {processed_count} | total failed: {failed_count}")
            return {'success': processed_count, 'failed': failed_count}

        except Exception as e:
            log_failure(f"Batch processing failed: {str(e)}")
            return {'success': processed_count, 'failed': failed_count + (len(codes) - processed_count - failed_count)}
        finally:
            # Make sure any connection held by this object is released
            try:
                if hasattr(self, 'db_conn') and self.db_conn:
                    self.db_conn.close()
                    self.db_conn = None
                    log_message("Database connection closed")
            except Exception as e:
                log_failure(f"Error while closing the database connection: {e}")

    def read_from_txt(self, file_path):
        """Read market data from a TXT file, in memory-friendly chunks."""
        try:
            # Read large files in chunks
            chunks = []
            for chunk in pd.read_csv(file_path, sep='\t', encoding='utf-8', chunksize=10000):
                # Pre-process each chunk
                chunk.columns = chunk.columns.str.lower()
                if 'trade_date' in chunk.columns:
                    chunk['trade_date'] = pd.to_datetime(chunk['trade_date'], format='%Y%m%d')
                chunks.append(chunk)

            if chunks:
                df = pd.concat(chunks, ignore_index=True)
                return df
            else:
                return pd.DataFrame()
        except Exception as e:
            log_failure(f"Failed to read file {file_path}: {e}")
            return None

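    # File-format note, grounded in save_to_txt()/read_from_txt(): the per-stock files are
    # tab-separated with a header row, and trade_date is stored as a YYYYMMDD string. The
    # exact column set is whatever the Tushare daily() call returns (an assumption here),
    # e.g. ts_code, trade_date, open, high, low, close, ... in some order.
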
    # Helper for optimizing how the data is stored in the database
    def optimize_database_storage(self):
        """Optimize the database storage layout."""
        engine = None
        try:
            engine = create_engine(
                f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@"
                f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4",
                pool_pre_ping=True
            )

            with engine.begin() as connection:
                # Add indexes to speed up queries
                connection.execute(text("ALTER TABLE stock_data_combined ADD INDEX IF NOT EXISTS idx_trade_date (trade_date)"))
                connection.execute(text("ALTER TABLE stock_data_combined ADD INDEX IF NOT EXISTS idx_stock_code (stock_code)"))
                # Remove duplicate rows, if any
                connection.execute(text("""
                    DELETE s1 FROM stock_data_combined s1
                    INNER JOIN stock_data_combined s2
                    WHERE s1.id > s2.id AND s1.trade_date = s2.trade_date AND s1.stock_code = s2.stock_code
                """))

            log_message("Database storage optimization finished")
            return True
        except Exception as e:
            log_failure(f"Database storage optimization failed: {e}")
            return False
        finally:
            if engine:
                try:
                    engine.dispose()
                    log_message("SQLAlchemy engine disposed")
                except Exception as e:
                    log_failure(f"Failed to dispose SQLAlchemy engine: {e}")


class ConsoleDataDownloader:
    """Console-mode data downloader."""

    def __init__(self):
        self.downloader = DataDownloader()

    def run_with_choice(self, choice):
        """Run the operation selected by `choice`."""
        try:
            if choice == '0':
                log_message("Exiting...")
                sys.exit(0)
            if choice == '6':  # full workflow
                log_process("Starting the full workflow...")
                start_time = time.time()

                # 1. Update stock codes
                self.downloader.get_all_stock_codes()

                # 2. Update index data
                self.downloader.fetch_and_save_index_data('000001.SH')

                # 3. Update per-stock data
                self.process_stock_codes()

                # 4. Sync the database
                if not self.downloader.connect_db():
                    log_failure("Database connection failed, skipping the database sync")
                else:
                    self.downloader.update_database()

                log_success(f"Full workflow finished! Total time: {time.time() - start_time:.2f}s")
                return
            # Check the database connection first (only for database operations)
            if choice == '5' and not self.downloader.connect_db():
                log_failure("Database connection failed, please check the configuration and retry")
                return

            # Option 7: copy data to the target directory
            if choice == '7':
                self.copy_data_to_target_directory()
                return

            # Option 8: check data integrity
            if choice == '8':
                self.check_data_integrity()
                return

            # Option 9: fetch per-stock basic information
            if choice == '9':
                self.downloader.fetch_and_save_stock_basic_info()
                return

            tasks = []
            if choice in ('1', '4'):
                tasks.append(self.downloader.get_all_stock_codes)
            if choice in ('2', '4'):
                tasks.append(lambda: self.downloader.fetch_and_save_index_data('000001.SH'))
            if choice in ('3', '4'):
                tasks.append(self.process_stock_codes)
            if choice == '5':  # database update task
                tasks.append(self.downloader.update_database)

            if not tasks:
                log_warning("Invalid option")
                return

            log_process("Starting tasks...")
            start_time = time.time()
            for task in tasks:
                task()
            log_success(f"Tasks finished! Total time: {time.time() - start_time:.2f}s")
        except Exception as e:
            log_failure(f"Error while running the task: {e}")

    def run(self):
        try:
            print("Select an operation:")
            print("1. Update the stock code table")
            print("2. Update index data")
            print("3. Update per-stock data")
            print("4. Update everything")
            print("5. Sync the database")
            print("6. Full workflow (download + database sync)")
            print("7. Copy the updated data to the target directory")
            print("8. Check data integrity")
            print("9. Fetch per-stock basic information")
            print("0. Exit")

            while True:
                try:
                    choice = input("Enter an option: ").strip()
                    # All dispatching is handled by run_with_choice(), which is also
                    # used for the command-line -c/--choice argument.
                    self.run_with_choice(choice)
                except KeyboardInterrupt:
                    print("\nInterrupted")
                    return
                except (ConnectionError, TimeoutError) as e:
                    log_failure(f"Network error during the interactive session: {str(e)}")
                except Exception as e:
                    log_failure(f"Error during the interactive session: {e}")
        except KeyboardInterrupt:
            log_warning("Interrupted by the user")
            print("Interrupted")
        except Exception as e:
            log_failure(f"Unexpected error: {str(e)}")

    def copy_data_to_target_directory(self):
        """Copy the data files from the day directory under Config.BASE_DIR to the target directory."""
        import os
        import shutil
        import time

        # Use Config.BASE_DIR as the base path for consistency
        source_dir = os.path.join(Config.BASE_DIR, 'day')
        target_dir = 'D:\\data\\jq_hc\\data'

        log_process(f"Preparing to copy data from {source_dir} to {target_dir}")

        try:
            # Make sure the target directory exists
            if not os.path.exists(target_dir):
                os.makedirs(target_dir, exist_ok=True)
                log_message(f"Created target directory: {target_dir}")

            # Check that the source directory exists
            if not os.path.exists(source_dir):
                log_failure(f"Source directory {source_dir} does not exist, please check the path")
                return

            # Collect the txt files in the source directory
            files = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f)) and f.lower().endswith('.txt')]
            total_files = len(files)

            if total_files == 0:
                log_warning("No files found in the source directory")
                return

            log_process(f"Copying {total_files} files...")
            start_time = time.time()
            copied_count = 0
            failed_count = 0

            for i, file_name in enumerate(files, 1):
                source_path = os.path.join(source_dir, file_name)
                target_path = os.path.join(target_dir, file_name)

                try:
                    # Copy the file, overwriting any existing one
                    shutil.copy2(source_path, target_path)
                    copied_count += 1

                    # Show progress
                    progress = (i / total_files) * 100
                    elapsed = time.time() - start_time
                    log_process(f"Progress: [{('#' * int(progress / 2)):50}] {progress:.1f}% | {file_name}")

                except Exception as e:
                    failed_count += 1
                    log_failure(f"Failed to copy file {file_name}: {str(e)}")

            # Final summary
            log_success("Copy finished!")
            log_result(f"Succeeded: {copied_count}, failed: {failed_count}")
            log_result(f"Total time: {time.time() - start_time:.2f}s")

        except KeyboardInterrupt:
            log_warning("Copy interrupted by the user")
        except Exception as e:
            log_failure(f"Error during the copy: {str(e)}")

    def check_data_integrity(self):
        """Check data integrity by running the check_market_data.py script."""
        import subprocess
        import os

        try:
            log_process("Starting the data integrity check...")

            # Full path to check_market_data.py
            script_path = os.path.join(os.path.dirname(__file__), 'check_market_data.py')

            # Run check_market_data.py and let it write straight to the console
            result = subprocess.run(
                [sys.executable, script_path],
                stdout=None,  # do not capture stdout; show it on the console
                stderr=None,  # do not capture stderr; show it on the console
                encoding='utf-8'
            )

            if result.returncode == 0:
                log_success("Data integrity check finished!")
            else:
                log_failure(f"Data integrity check failed, return code: {result.returncode}")

        except FileNotFoundError:
            log_failure(f"check_market_data.py not found, please check the path: {script_path}")
        except Exception as e:
            log_failure(f"Error while running the data integrity check: {e}")

    def process_stock_codes(self):
        """Process all stock codes, batched across a thread pool."""
        try:
            with open(Config.INPUT_FILE, 'r', encoding='utf-8') as f:
                stock_codes = [line.strip().split('\t')[0] for line in f if line.strip()]
        except Exception as e:
            log_failure(f"Failed to read the stock code file: {e}")
            return

        total = len(stock_codes)
        completed = 0
        start_time = time.time()

        def print_progress():
            progress = (completed / total) * 100
            elapsed = time.time() - start_time
            # Single-line progress report
            log_process(
                f"Progress: [{'#' * int(progress / 2)}{' ' * (50 - int(progress / 2))}] {progress:.1f}% | done: {completed}/{total} | elapsed: {elapsed:.1f}s")

        # Process in batches to avoid creating too many threads at once
        batch_size = Config.MAX_THREADS * 2
        for i in range(0, total, batch_size):
            batch = stock_codes[i:i + batch_size]

            with ThreadPoolExecutor(max_workers=Config.MAX_THREADS) as executor:
                futures = {
                    executor.submit(self.downloader.process_stock_code, code, None): code
                    for code in batch
                }

                for future in as_completed(futures):
                    completed += 1
                    if completed % 5 == 0 or completed == total:
                        print_progress()
                    future.result()

        log_success(f"Progress: [{('#' * 50)}] 100.0% | done: {total}/{total} | total time: {time.time() - start_time:.1f}s")
        log_success("Per-stock data download finished!")


# Module-level AccountManager instance
_global_account_manager = AccountManager()

@lru_cache(maxsize=1000)
def get_cached_data(code):
    """Cache the result of a daily-data query per stock code."""
    pro = _global_account_manager.get_next_account()  # use the shared AccountManager instance
    return pro.daily(ts_code=code)

def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Tushare data update tool')
    parser.add_argument('-c', '--choice', type=str, choices=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
                        help='Operation to run: 0=exit, 1=update the stock code table, 2=update index data, '
                             '3=update per-stock data, 4=update everything, 5=sync the database, 6=full workflow, '
                             '7=copy data to the target directory, 8=check data integrity, 9=fetch per-stock basic information')
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    downloader = ConsoleDataDownloader()
    if args.choice:
        # Run the requested operation directly from the command line
        downloader.run_with_choice(args.choice)
    else:
        # Fall back to interactive mode
        downloader.run()
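
# Example invocations (a sketch; the script filename is assumed from the logger name above):
#   python update_tushare_totxt.py -c 6    # full workflow: download everything, then sync the database
#   python update_tushare_totxt.py         # no arguments: interactive menu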