from threading import Thread import os import time import logging import threading import socket import sys from datetime import timedelta from concurrent.futures import ThreadPoolExecutor, as_completed import pandas as pd import tushare as ts from log_style import setup_logger, get_logger from functools import lru_cache import argparse # 配置日志记录器 logger = setup_logger( name='update_tushare_totxt', log_file='app.log', level=logging.INFO, log_format='%(asctime)s - %(levelname)s - %(message)s', console=True ) def log_warning(message, **kwargs): """警告信息,黄色显示""" logger.warning(f"\033[33m{message}\033[0m", **kwargs) def log_failure(message, **kwargs): """失败信息,红色显示""" logger.error(f"\033[31m{message}\033[0m", **kwargs) # 定义带颜色的日志函数 def log_success(message, **kwargs): """成功信息,绿色显示""" logger.info(f"\033[32m{message}\033[0m", **kwargs) def log_failure(message, **kwargs): """失败信息,红色显示""" logger.error(f"\033[31m{message}\033[0m", **kwargs) def log_process(message, **kwargs): """处理中信息,黄色显示""" logger.info(f"\033[33m{message}\033[0m", **kwargs) def log_result(message, **kwargs): """结果信息,蓝色显示""" logger.info(f"\033[34m{message}\033[0m", **kwargs) def log_message(message, **kwargs): """普通信息,白色显示""" logger.info(f"\033[37m{message}\033[0m", **kwargs) # 数据库相关导入 import pymysql from sqlalchemy import create_engine, text class Config: """全局配置类""" BASE_DIR = 'D:\\gp_data' INPUT_FILE = os.path.join(BASE_DIR, 'code', 'all_stock_codes.txt') OUTPUT_DIR = os.path.join(BASE_DIR, 'day') INDEX_DIR = os.path.join(BASE_DIR, 'index') # 最大线程数,可通过环境变量MAX_THREADS设置,默认10 MAX_THREADS = int(os.getenv('MAX_THREADS', 20)) # 每分钟最大请求次数限制,可通过环境变量REQUEST_LIMIT设置,默认200 REQUEST_LIMIT = int(os.getenv('REQUEST_LIMIT', 500)) # 请求失败时的最大重试次数 MAX_RETRIES = 3 # 时间常量,表示60秒(1分钟),用于请求频率控制 ONE_MINUTE = 60 # 多账户Token数组,从环境变量读取,如果没有则使用默认值 ACCOUNTS = [ os.getenv('TUSHARE_TOKEN1', '9343e641869058684afeadfcfe7fd6684160852e52e85332a7734c8d'), # 主账户 os.getenv('TUSHARE_TOKEN2', 'b156282fb2328afa056346736a93778861ac6be97a485f4d1f55c493') # 备用账户1 ] # 新增数据库重试配置 DB_MAX_RETRIES = 3 # 最大重试次数 DB_RETRY_INTERVAL = 3 # 重试间隔(秒) # 新增数据库配置,从环境变量读取敏感信息 DB_CONFIG = { 'host': os.getenv('DB_HOST', '127.0.0.1'), 'port': int(os.getenv('DB_PORT', 3306)), 'user': os.getenv('DB_USER', 'root'), 'password': os.getenv('DB_PASSWORD', 'stella0850'), 'database': os.getenv('DB_NAME', 'stock_data') } # 数据库连接池配置 DB_POOL_SIZE = 20 # 连接池大小 DB_MAX_OVERFLOW = 10 # 最大溢出连接数 class AccountManager: # 类变量缓存,实现跨实例的缓存共享 _stock_basic_cache = None _stock_basic_cache_time = 0 def __init__(self): """初始化账户管理器""" self.accounts = [ts.pro_api(token) for token in Config.ACCOUNTS] self.current_index = 0 self.lock = threading.Lock() # 为每个账户单独维护请求时间,提高并发效率 self.last_request_time_per_account = [time.time() for _ in self.accounts] def get_stock_basic(self, force_primary=False): """获取股票基本信息,带缓存机制""" current_time = time.time() # 使用类变量缓存,实现跨实例的缓存共享 # 缓存有效期为30分钟 if (AccountManager._stock_basic_cache is not None and current_time - AccountManager._stock_basic_cache_time < 30 * 60): return AccountManager._stock_basic_cache # 缓存失效,重新获取 pro = self.get_next_account(force_primary) df_sse = pro.stock_basic(exchange='SSE', list_status='L') df_szse = pro.stock_basic(exchange='SZSE', list_status='L') AccountManager._stock_basic_cache = pd.concat([df_sse, df_szse], ignore_index=True) AccountManager._stock_basic_cache_time = current_time return AccountManager._stock_basic_cache def get_next_account(self, force_primary=False): """轮询获取下一个账户 force_primary: 强制使用主账户 """ with self.lock: try: if force_primary: return self.accounts[0] # 强制返回主账户 account = self.accounts[self.current_index] self.current_index = (self.current_index + 1) % len(self.accounts) return account except (ConnectionError, TimeoutError) as e: # 网络连接相关异常 log_failure(f"账户网络连接异常: {str(e)}") except Exception as e: # 其他异常 log_failure(f"获取账户异常: {str(e)}") # 无论什么异常,都尝试下一个账户 self.current_index = (self.current_index + 1) % len(self.accounts) # 防止无限递归,设置最大尝试次数 retry_count = getattr(self, '_retry_count', 0) if retry_count >= len(self.accounts) * 2: # 尝试每个账户2次 setattr(self, '_retry_count', 0) log_failure("所有账户都尝试失败,返回最后一个账户作为尝试") return self.accounts[0] # 返回第一个账户作为最后的尝试 setattr(self, '_retry_count', retry_count + 1) return self.get_next_account(force_primary) class DataDownloader: """数据下载器核心类""" def __init__(self): self.account_manager = AccountManager() self.setup_logging() self.create_dirs() self.last_request_time = time.time() self.db_conn = None # 新增数据库连接属性 self.total_files = 0 self.processed_files = 0 self.progress = {'current': 0, 'total': 100, 'message': '准备开始'} # 数据库连接池管理 self.conn_pool = {} self.conn_count = 0 self.max_connections = 5 # 最大连接数限制 self.active_connections = set() def show_progress(self, current, total, start_time): """进度显示""" self.progress = { 'current': current, 'total': total, 'message': f'处理中: {current / total * 100:.1f}% | 耗时: {time.time() - start_time:.1f}s' } progress = current / total * 100 elapsed = time.time() - start_time log_process(f"\r处理中: {progress:.1f}% ({current}/{total}) | 耗时: {elapsed:.1f}s", end='', flush=True) if current == total: log_success(f"\n数据处理完成!") # 添加换行符和完成信息 # def show_progress(self, current, total, start_time): # """更频繁的进度显示""" # if current == total or current % 2 == 0: # progress = current / total * 100 # elapsed = time.time() - start_time # print(f"\r处理中: {progress:.1f}% ({current}/{total}) | 耗时: {elapsed:.1f}s", end='') # if current == total: # print(" 完成!") def setup_logging(self): logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler(), logging.FileHandler('app.log') ] ) def connect_db(self): """连接数据库,带重试机制和自动创建数据库功能""" # 检查连接池是否有可用连接 if self.conn_pool: # 使用第一个可用连接 conn_id, conn = next(iter(self.conn_pool.items())) try: # 测试连接是否有效 cursor = conn.cursor() cursor.execute("SELECT 1") cursor.close() del self.conn_pool[conn_id] self.active_connections.add(conn_id) log_message(f"复用数据库连接 {conn_id}") self.db_conn = conn return True except Exception: # 连接无效,移除并继续创建新连接 try: conn.close() except: pass del self.conn_pool[conn_id] log_warning(f"移除无效连接 {conn_id}") def return_connection(self, conn=None): """返回连接到连接池或关闭连接""" # 如果提供了连接,使用提供的连接 if conn: try: # 测试连接是否仍然有效 cursor = conn.cursor() cursor.execute("SELECT 1") cursor.close() # 生成唯一连接ID conn_id = f"conn_{int(time.time() * 1000)}" # 从活动连接中移除并添加到连接池 if hasattr(self, 'active_connections'): for active_conn_id in list(self.active_connections): if active_conn_id.endswith(str(hash(str(conn)))): self.active_connections.remove(active_conn_id) # 添加到连接池 self.conn_pool[conn_id] = conn log_message(f"连接已返回连接池: {conn_id}") return True except Exception: # 连接已无效,关闭它 try: conn.close() if hasattr(self, 'conn_count'): self.conn_count = max(0, self.conn_count - 1) log_message(f"无效连接已关闭,当前连接数: {self.conn_count}") except: pass return False # 如果使用的是对象的默认连接 elif hasattr(self, 'db_conn') and self.db_conn: try: # 测试连接是否仍然有效 cursor = self.db_conn.cursor() cursor.execute("SELECT 1") cursor.close() # 生成唯一连接ID conn_id = f"conn_{int(time.time() * 1000)}" # 从活动连接中移除并添加到连接池 for active_conn_id in list(self.active_connections): self.active_connections.remove(active_conn_id) # 添加到连接池 self.conn_pool[conn_id] = self.db_conn self.db_conn = None # 清除对象属性 log_message(f"默认连接已返回连接池: {conn_id}") return True except Exception: # 连接已无效,关闭它 try: self.db_conn.close() self.db_conn = None self.conn_count = max(0, self.conn_count - 1) log_message(f"默认无效连接已关闭,当前连接数: {self.conn_count}") except: pass return False return False # 如果连接数超过限制,等待一段时间 while self.conn_count >= self.max_connections: log_warning(f"数据库连接数达到上限 {self.max_connections},等待释放...") time.sleep(1) for attempt in range(Config.DB_MAX_RETRIES): try: # 先尝试连接指定数据库 conn = pymysql.connect(**Config.DB_CONFIG) conn_id = f"conn_{int(time.time() * 1000)}" self.conn_count += 1 self.active_connections.add(conn_id) log_message(f"创建数据库连接 {conn_id},当前连接数: {self.conn_count}") self.db_conn = conn return True # 静默返回,不记录成功日志 except pymysql.err.OperationalError as e: if e.args[0] == 1049: # 数据库不存在错误码 try: # 连接MySQL服务器(不带数据库名) temp_config = Config.DB_CONFIG.copy() temp_config.pop('database') temp_conn = pymysql.connect(**temp_config) # 创建数据库(使用Config.DB_CONFIG中的database名) with temp_conn.cursor() as cursor: cursor.execute(f"CREATE DATABASE IF NOT EXISTS {Config.DB_CONFIG['database']} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") temp_conn.commit() temp_conn.close() log_message(f"数据库 {Config.DB_CONFIG['database']} 创建成功") # 重新尝试连接 continue except Exception as create_e: log_failure(f"创建数据库失败: {create_e}") return False elif attempt < Config.DB_MAX_RETRIES - 1: # 不是最后一次尝试 log_warning(f"数据库连接失败(尝试 {attempt + 1}/{Config.DB_MAX_RETRIES}): {e}") time.sleep(Config.DB_RETRY_INTERVAL) else: log_failure(f"数据库连接失败,已达到最大重试次数 {Config.DB_MAX_RETRIES}") return False except Exception as e: if attempt < Config.DB_MAX_RETRIES - 1: # 不是最后一次尝试 log_warning(f"数据库连接失败(尝试 {attempt + 1}/{Config.DB_MAX_RETRIES}): {e}") time.sleep(Config.DB_RETRY_INTERVAL) else: log_failure(f"数据库连接失败,已达到最大重试次数 {Config.DB_MAX_RETRIES}") return False def update_to_db(self, data: pd.DataFrame, table_name: str): """线程安全的数据库更新方法""" engine = None try: # 每个线程创建自己的引擎和连接,使用更好的连接池配置 engine = create_engine( f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@" f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4", pool_size=Config.MAX_THREADS, # 设置连接池大小 max_overflow=Config.DB_MAX_OVERFLOW, pool_timeout=30, # 连接超时时间 pool_recycle=3600, # 连接回收时间 pool_pre_ping=True, # 连接前检查有效性 echo=False, connect_args={ "charset": "utf8mb4", "autocommit": True, # 自动提交 "connect_timeout": 30, "read_timeout": 30, "write_timeout": 30 } ) # 强制表名小写并移除特殊字符 table_name = table_name.lower().replace('.', '_') # 添加表结构检查 if table_name.startswith('stock_'): if 'trade_date' not in data.columns: log_failure(f"数据缺少必要列 trade_date") return False if not pd.api.types.is_datetime64_any_dtype(data['trade_date']): data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d') with engine.begin() as connection: # 使用事务 data.to_sql(table_name, connection, if_exists='append', index=False, chunksize=1000) return True except Exception as e: log_failure(f"数据库更新失败: {e}") return False finally: # 确保资源释放 if engine: try: engine.dispose() log_message(f"SQLAlchemy引擎已释放") except Exception as e: log_failure(f"释放SQLAlchemy引擎失败: {e}") def update_database(self, progress_queue=None, completion_label=None): try: log_process("开始更新数据库...") start_time = time.time() # 1. 更新股票代码表(使用缓存机制) log_process("正在更新股票代码表...") df_all = self.account_manager.get_stock_basic(force_primary=True) self.update_to_db(df_all, 'stock_basic') # 2. 从本地文件更新日线数据 log_process("正在更新个股日线数据...") files = [f for f in os.listdir(Config.OUTPUT_DIR) if f.endswith('_daily_data.txt')] self.total_files = len(files) self.processed_files = 0 if self.total_files > 0: # 使用批量插入优化数据库写入 self.batch_update_stock_data(files, start_time) else: log_message("没有找到需要处理的个股数据文件") # 初始显示进度0% # self.show_progress(0, self.total_files, start_time) # # # 使用批量插入优化数据库写入 # self.batch_update_stock_data(files, start_time) # for i, file in enumerate(files, 1): # file_path = os.path.join(Config.OUTPUT_DIR, file) # df = self.read_from_txt(file_path) # if df is not None: # code = file.split('_')[0].lower().replace('.', '_') # self.update_to_db(df, f'stock_{code}') # # self.processed_files = i # self.show_progress(i, self.total_files, start_time) # 3. 更新指数数据(从本地文件) log_process("正在更新指数数据...") index_file = os.path.join(Config.INDEX_DIR, '000001.SH.txt') if os.path.exists(index_file): df_index = self.read_from_txt(index_file) if df_index is not None: self.update_to_db(df_index, 'index_daily') # 4. 更新个股数据(分批处理) log_process("正在更新实时个股数据...") stock_codes = df_all['ts_code'].tolist() batch_size = 400 for i in range(0, len(stock_codes), batch_size): batch = stock_codes[i:i + batch_size] df = self.fetch_data_with_retry(pro.daily, ts_code=','.join(batch)) if df is not None: self.update_to_db(df, 'stock_daily') log_success(f"数据库更新完成! 总耗时: {time.time() - start_time:.2f}秒") if completion_label: completion_label.config(text="数据库更新完成!", foreground="green") if progress_queue: progress_queue.put(1) return True except Exception as e: log_failure(f"数据库更新失败: {e}") if completion_label: completion_label.config(text="数据库更新失败!", foreground="red") return False def batch_update_stock_data(self, files, start_time): """批量更新股票数据到数据库以提高性能 - 优化版本""" engine = None try: # 创建一个共享的数据库引擎,使用更好的连接池配置 engine = create_engine( f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@" f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4", pool_size=Config.MAX_THREADS, max_overflow=Config.DB_MAX_OVERFLOW, pool_timeout=30, pool_recycle=3600, pool_pre_ping=True, echo=False # 关闭SQL日志以提高性能 ) # 批量处理文件,每批处理100个文件 batch_size = 100 processed_count = 0 total_files = len(files) for i in range(0, len(files), batch_size): batch_files = files[i:i + batch_size] all_data = [] # 收集一批数据 for file in batch_files: file_path = os.path.join(Config.OUTPUT_DIR, file) df = self.read_from_txt(file_path) if df is not None: code = file.split('_')[0].lower().replace('.', '_') df['stock_code'] = code # 添加股票代码标识 all_data.append(df) # 更新单个文件处理进度 processed_count += 1 self.show_progress(processed_count, total_files, start_time) # 合并数据并批量插入 if all_data: combined_df = pd.concat(all_data, ignore_index=True) # 使用事务批量插入 with engine.begin() as connection: # 先删除重复数据(如果有需要) # 然后批量插入新数据 combined_df.to_sql('stock_data_combined', connection, if_exists='append', index=False, chunksize=2000, method='multi') # 确保最终进度显示为100% if total_files > 0: self.show_progress(total_files, total_files, start_time) except Exception as e: log_failure(f"批量更新股票数据失败: {e}") finally: # 确保资源释放 if engine: try: engine.dispose() log_message(f"批量更新的SQLAlchemy引擎已释放") except Exception as e: log_failure(f"释放批量更新的SQLAlchemy引擎失败: {e}") def create_dirs(self): # 创建基础目录 os.makedirs(Config.BASE_DIR, exist_ok=True) # 创建代码文件目录 os.makedirs(os.path.dirname(Config.INPUT_FILE), exist_ok=True) # 创建输出目录 os.makedirs(Config.OUTPUT_DIR, exist_ok=True) # 创建指数目录 os.makedirs(Config.INDEX_DIR, exist_ok=True) # 优化 fetch_data_with_retry 方法,改进请求频率控制 def fetch_data_with_retry(self, func, *args, **kwargs): """带重试机制的数据获取方法 - 优化版本""" # 计算请求频率限制(总请求数/分钟) total_requests_per_minute = Config.REQUEST_LIMIT * len(Config.ACCOUNTS) requests_per_second = total_requests_per_minute / 60.0 min_interval = 1.0 / requests_per_second for attempt in range(Config.MAX_RETRIES): try: # 精确控制请求间隔 current_time = time.time() time_since_last_request = current_time - self.last_request_time # 确保请求间隔不小于最小限制 if time_since_last_request < min_interval: time.sleep(min_interval - time_since_last_request) result = func(*args, **kwargs) self.last_request_time = time.time() return result if result is not None and not (hasattr(result, 'empty') and result.empty) else None except Exception as e: wait_time = min(2 ** attempt, 5) time.sleep(wait_time) def save_to_txt(self, data: pd.DataFrame, filename: str) -> bool: try: # 按交易日期降序排序,确保最新交易日排在最前面 if 'trade_date' in data.columns: data = data.sort_values('trade_date', ascending=False) data.to_csv(filename, index=False, sep='\t', encoding='utf-8') # logging.info(f"数据已保存到 {filename}") return True except Exception as e: log_failure(f"保存文件时出错: {e}") return False def process_stock_code(self, code, progress_queue=None): # 修改参数默认值为None pro = self.account_manager.get_next_account() try: output_file = os.path.join(Config.OUTPUT_DIR, f"{code}_daily_data.txt") # 检查是否存在现有数据文件 if os.path.exists(output_file): # 读取现有数据,获取最新的交易日期 existing_df = self.read_from_txt(output_file) if existing_df is not None and not existing_df.empty: # 获取最新交易日期 if 'trade_date' in existing_df.columns: # 由于read_from_txt会将trade_date转换为datetime格式 # 确保现有数据的trade_date列是datetime格式 if not pd.api.types.is_datetime64_any_dtype(existing_df['trade_date']): existing_df['trade_date'] = pd.to_datetime(existing_df['trade_date'], format='%Y%m%d') # 获取最新交易日期 latest_date_dt = existing_df['trade_date'].max() # 计算下一个交易日的起始日期(避免重复获取同一天数据) next_date_dt = latest_date_dt + timedelta(days=1) next_date = next_date_dt.strftime('%Y%m%d') # 获取最新日期之后的数据 df = self.fetch_data_with_retry(pro.daily, ts_code=code, start_date=next_date) if df is not None and not df.empty: # 将新数据的trade_date列转换为datetime格式,以便合并 df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d') # 合并现有数据和新数据 combined_df = pd.concat([existing_df, df], ignore_index=True) # 去重,避免重复数据 combined_df = combined_df.drop_duplicates(subset=['trade_date', 'ts_code'], keep='last') # 按交易日期降序排序,最新交易日排在最前面 combined_df = combined_df.sort_values('trade_date', ascending=False) # 将trade_date转换回字符串格式保存 combined_df['trade_date'] = combined_df['trade_date'].dt.strftime('%Y%m%d') # 保存合并后的数据 self.save_to_txt(combined_df, output_file) else: # 如果现有数据没有 trade_date 列,重新获取全部数据 df = self.fetch_data_with_retry(pro.daily, ts_code=code) if df is not None: self.save_to_txt(df, output_file) else: # 现有数据为空,重新获取全部数据 df = self.fetch_data_with_retry(pro.daily, ts_code=code) if df is not None: self.save_to_txt(df, output_file) else: # 文件不存在,获取全部数据 df = self.fetch_data_with_retry(pro.daily, ts_code=code) if df is not None: self.save_to_txt(df, output_file) if progress_queue is not None: # 添加判断 progress_queue.put(1) except (ConnectionError, TimeoutError) as e: log_failure(f"股票 {code} 网络连接异常: {str(e)}") except pd.errors.EmptyDataError: log_warning(f"股票 {code} 返回空数据") except (pymysql.err.OperationalError, pymysql.err.InterfaceError) as e: log_failure(f"股票 {code} 数据库操作异常: {str(e)}") except Exception as e: log_failure(f"处理股票代码 {code} 时出错: {str(e)}") finally: if progress_queue is not None: # 添加判断 progress_queue.put(1) def fetch_and_save_index_data(self, ts_code='000001.SH', progress_queue=None, completion_label=None): pro = self.account_manager.get_next_account(force_primary=True) # 强制使用主账户 try: df = pro.index_daily(ts_code=ts_code) if df is None or df.empty: log_warning(f"未获取到指数 {ts_code} 的日线数据") return False os.makedirs(Config.INDEX_DIR, exist_ok=True) output_file = os.path.join(Config.INDEX_DIR, f"{ts_code}.txt") self.save_to_txt(df, output_file) if completion_label: completion_label.config(text="指数数据下载完成!", foreground="green") if progress_queue: progress_queue.put(1) return True except Exception as e: log_failure(f"Error fetching index data for {ts_code}: {e}") if completion_label: completion_label.config(text="获取指数数据失败!", foreground="red") return False def get_all_stock_codes(self, progress_queue=None, completion_label=None): try: # 使用缓存的股票基本信息 df_all = self.account_manager.get_stock_basic(force_primary=True) df_all[['ts_code', 'name']].to_csv(Config.INPUT_FILE, index=False, header=False, sep='\t') log_success(f'所有资料已保存到 {Config.INPUT_FILE}') if completion_label: completion_label.config(text="代码表更新完成!", foreground="green") if progress_queue: progress_queue.put(1) except Exception as e: logging.error(f'发生错误: {e}') if completion_label: completion_label.config(text="更新代码表失败!", foreground="red") def fetch_and_save_stock_basic_info(self): """获取个股基础信息并保存到baseinfo目录""" try: log_process("开始获取个股基础信息...") # 获取Tushare API实例 pro = self.account_manager.get_next_account(force_primary=True) # 强制使用主账户 # 获取股票基础信息 df = pro.stock_basic( exchange='', # 交易所代码,不指定则查询所有 list_status='L', # 上市状态:L上市,D退市,P暂停上市 fields='ts_code,symbol,name,area,industry,fullname,enname,market,exchange,curr_type,list_status,list_date,delist_date,is_hs' ) if df is None or df.empty: log_failure("未获取到个股基础信息") return False # 创建baseinfo目录 baseinfo_dir = os.path.join(Config.BASE_DIR, 'baseinfo') os.makedirs(baseinfo_dir, exist_ok=True) # 保存数据到文件,使用TXT格式而不是CSV output_file = os.path.join(baseinfo_dir, 'stock_basic_info.txt') self.save_to_txt(df, output_file) log_success(f"个股基础信息已成功保存到: {output_file}") log_result(f"共获取到 {len(df)} 条个股基础信息") return True except Exception as e: log_failure(f"获取个股基础信息失败: {str(e)}") log_failure(f"获取个股基础信息失败: {str(e)}") return False def process_stock_codes_batch(self, codes: list): """批量处理股票代码(每次最多4000个)""" pro = self.account_manager.get_next_account() processed_count = 0 failed_count = 0 try: # 分批处理,每批400个(考虑单次返回限制) batch_size = 400 total_batches = (len(codes) + batch_size - 1) // batch_size for batch_idx in range(0, len(codes), batch_size): batch = codes[batch_idx:batch_idx + batch_size] try: # 获取批量数据 df = self.fetch_data_with_retry(pro.daily, ts_code=','.join(batch)) if df is not None and not df.empty: # 按股票代码拆分数据并保存 for code in batch: try: # 筛选当前股票的数据 code_df = df[df['ts_code'] == code] if not code_df.empty: # 保存到文件 output_file = os.path.join(Config.OUTPUT_DIR, f"{code}_daily_data.txt") if self.save_to_txt(code_df, output_file): # 更新到数据库 if self.update_to_db(code_df, 'stock_daily'): processed_count += 1 else: failed_count += 1 log_warning(f"数据库更新失败,但文件已保存: {code}") else: failed_count += 1 log_failure(f"文件保存失败: {code}") else: log_warning(f"未获取到股票 {code} 的数据") failed_count += 1 except Exception as code_e: failed_count += 1 log_failure(f"处理股票 {code} 时出错: {str(code_e)}") else: failed_count += len(batch) log_warning(f"批次 {batch_idx//batch_size + 1}/{total_batches} 返回空数据") # 显示批次进度 log_result(f"批次 {batch_idx//batch_size + 1}/{total_batches} 处理完成 | 成功: {processed_count} | 失败: {failed_count}") except (ConnectionError, TimeoutError) as e: failed_count += len(batch) log_failure(f"批次 {batch_idx//batch_size + 1} 网络异常: {str(e)}") # 更换账户后重试 pro = self.account_manager.get_next_account() except Exception as batch_e: failed_count += len(batch) log_failure(f"批次 {batch_idx//batch_size + 1} 处理异常: {str(batch_e)}") # 记录总体处理结果 log_result(f"批量处理完成 | 总成功: {processed_count} | 总失败: {failed_count}") return {'success': processed_count, 'failed': failed_count} except Exception as e: log_failure(f"批量处理主流程失败: {str(e)}") return {'success': processed_count, 'failed': failed_count + (len(codes) - processed_count - failed_count)} finally: # 确保所有连接正确管理 # 添加返回连接方法的调用 try: if hasattr(self, 'db_conn') and self.db_conn: self.db_conn.close() self.db_conn = None log_message("数据库连接已关闭") except Exception as e: log_failure(f"关闭数据库连接时出错: {e}") def read_from_txt(self, file_path): """从TXT文件读取行情数据 - 内存优化版本""" try: # 使用chunksize分块读取大文件 chunks = [] for chunk in pd.read_csv(file_path, sep='\t', encoding='utf-8', chunksize=10000): # 对每一块进行预处理 chunk.columns = chunk.columns.str.lower() if 'trade_date' in chunk.columns: chunk['trade_date'] = pd.to_datetime(chunk['trade_date'], format='%Y%m%d') chunks.append(chunk) if chunks: df = pd.concat(chunks, ignore_index=True) return df else: return pd.DataFrame() except Exception as e: log_failure(f"读取文件 {file_path} 失败: {e}") return None # 添加新的方法用于优化数据库存储 def optimize_database_storage(self): """优化数据库存储结构""" engine = None try: engine = create_engine( f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@" f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4", pool_pre_ping=True ) with engine.begin() as connection: # 添加索引以提高查询速度 connection.execute(text("ALTER TABLE stock_data_combined ADD INDEX IF NOT EXISTS idx_trade_date (trade_date)")) connection.execute(text("ALTER TABLE stock_data_combined ADD INDEX IF NOT EXISTS idx_stock_code (stock_code)")) # 如果有重复数据,添加去重逻辑 connection.execute(text(""" DELETE s1 FROM stock_data_combined s1 INNER JOIN stock_data_combined s2 WHERE s1.id > s2.id AND s1.trade_date = s2.trade_date AND s1.stock_code = s2.stock_code """)) log_message("数据库存储优化完成") return True except Exception as e: log_failure(f"数据库存储优化失败: {e}") return False finally: if engine: try: engine.dispose() log_message(f"SQLAlchemy引擎已释放") except Exception as e: log_failure(f"释放SQLAlchemy引擎失败: {e}") class ConsoleDataDownloader: """控制台模式数据下载器""" def __init__(self): self.downloader = DataDownloader() def run_with_choice(self, choice): """使用指定的选项运行""" try: if choice == '0': log_message("程序正在退出...") sys.exit(0) if choice == '6': # 新增全部工作逻辑 log_process("开始执行全部工作...") start_time = time.time() # 1. 更新股票代码 self.downloader.get_all_stock_codes() # 2. 更新指数数据 self.downloader.fetch_and_save_index_data('000001.SH') # 3. 更新个股数据 self.process_stock_codes() # 4. 同步数据库 if not self.downloader.connect_db(): log_failure("数据库连接失败,跳过数据库同步") else: self.downloader.update_database() log_success(f"全部工作完成! 总耗时: {time.time() - start_time:.2f}秒") return # 检查数据库连接是否正常(仅当选择数据库相关操作时) if choice == '5' and not self.downloader.connect_db(): log_failure("数据库连接失败,请检查配置后重试") return # 处理选项7:拷贝数据 if choice == '7': self.copy_data_to_target_directory() return # 处理选项8:检查数据完整性 if choice == '8': self.check_data_integrity() return # 处理选项9:获取个股基础信息 if choice == '9': self.downloader.fetch_and_save_stock_basic_info() return tasks = [] if choice in ('1', '4'): tasks.append(self.downloader.get_all_stock_codes) if choice in ('2', '4'): tasks.append(lambda: self.downloader.fetch_and_save_index_data('000001.SH')) if choice in ('3', '4'): tasks.append(self.process_stock_codes) if choice == '5': # 新增数据库更新任务 tasks.append(self.downloader.update_database) if not tasks: log_warning("无效选项") return log_process("开始执行任务...") start_time = time.time() for task in tasks: task() log_success(f"任务完成! 总耗时: {time.time() - start_time:.2f}秒") except Exception as e: log_failure(f"执行任务时发生异常: {e}") log_failure(f"执行任务时发生异常: {e}") def run(self): try: print("请选择要执行的操作:") print("1. 更新股票代码表") print("2. 更新指数数据") print("3. 更新个股数据") print("4. 全部更新") print("5. 同步数据库") # 新增选项 print("6. 全部工作(下载+同步数据库)") # 新增选项 print("7. 拷贝更新后的数据到目标目录") # 新增选项 print("8. 检查数据完整性") # 新增选项 print("9. 获取个股基础信息") # 新增选项 print("0. 退出") while True: try: choice = input("请输入选项: ").strip() self.run_with_choice(choice) except KeyboardInterrupt: print("\n程序已中断") return except Exception as e: log_failure(f"用户交互过程中发生异常: {e}") log_failure(f"操作异常: {e}") if choice == '6': # 新增全部工作逻辑 log_process("开始执行全部工作...") start_time = time.time() # 1. 更新股票代码 self.downloader.get_all_stock_codes() # 2. 更新指数数据 self.downloader.fetch_and_save_index_data('000001.SH') # 3. 更新个股数据 self.process_stock_codes() # 4. 同步数据库 if not self.downloader.connect_db(): log_failure("数据库连接失败,跳过数据库同步") else: self.downloader.update_database() log_success(f"全部工作完成! 总耗时: {time.time() - start_time:.2f}秒") continue # 检查数据库连接是否正常(仅当选择数据库相关操作时) # 检查数据库连接是否正常(仅当选择数据库相关操作时) if choice == '5' and not self.downloader.connect_db(): log_failure("数据库连接失败,请检查配置后重试") continue # 处理选项7:拷贝数据 if choice == '7': self.copy_data_to_target_directory() continue tasks = [] if choice in ('1', '4'): tasks.append(self.downloader.get_all_stock_codes) if choice in ('2', '4'): tasks.append(lambda: self.downloader.fetch_and_save_index_data('000001.SH')) if choice in ('3', '4'): tasks.append(self.process_stock_codes) if choice == '5': # 新增数据库更新任务 tasks.append(self.downloader.update_database) if not tasks: log_warning("无效选项,请重新输入") continue log_process("开始执行任务...") start_time = time.time() completed = 0 failed = 0 for task in tasks: try: task() completed += 1 except (ConnectionError, TimeoutError) as e: failed += 1 log_failure(f"任务执行网络异常: {str(e)}") except Exception as e: failed += 1 log_failure(f"任务执行异常: {str(e)}") log_success(f"任务完成! 成功: {completed}, 失败: {failed}, 总耗时: {time.time()-start_time:.2f}秒") except (ConnectionError, TimeoutError) as e: log_failure(f"用户交互过程中发生网络异常: {str(e)}") log_failure(f"网络异常: {str(e)}") except Exception as e: log_failure(f"用户交互过程中发生异常: {str(e)}") log_failure(f"操作异常: {str(e)}") except KeyboardInterrupt: log_warning("用户中断程序") print("程序已中断") except Exception as e: log_failure(f"程序运行异常: {str(e)}") def copy_data_to_target_directory(self): """拷贝D:\gp_data\day目录下的数据到D:\gp_data\floor_cl\data目录""" import os import shutil import time # 使用Config.BASE_DIR作为基础路径,保持代码一致性 source_dir = os.path.join(Config.BASE_DIR, 'day') target_dir = 'D:\\data\\jq_hc\\data' log_process(f"准备从 {source_dir} 拷贝数据到 {target_dir}") try: # 确保目标目录存在 if not os.path.exists(target_dir): os.makedirs(target_dir, exist_ok=True) log_message(f"已创建目标目录: {target_dir}") # 检查源目录是否存在 if not os.path.exists(source_dir): log_failure(f"源目录 {source_dir} 不存在,请检查路径") return # 获取源目录中的txt文件列表 files = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f)) and f.lower().endswith('.txt')] total_files = len(files) if total_files == 0: log_warning("源目录中没有找到文件") return log_process(f"开始拷贝 {total_files} 个文件...") start_time = time.time() copied_count = 0 failed_count = 0 for i, file_name in enumerate(files, 1): source_path = os.path.join(source_dir, file_name) target_path = os.path.join(target_dir, file_name) try: # 拷贝文件并覆盖已存在的文件 shutil.copy2(source_path, target_path) copied_count += 1 # 显示进度 progress = (i / total_files) * 100 elapsed = time.time() - start_time log_process(f"\r进度: [{('#' * int(progress / 2)):50}] {progress:.1f}% | {file_name}", end='', flush=True) except Exception as e: failed_count += 1 log_failure(f"\n拷贝文件 {file_name} 失败: {str(e)}") # 打印完成信息 log_success(f"\n拷贝完成!") log_result(f"成功: {copied_count}, 失败: {failed_count}") log_result(f"总耗时: {time.time() - start_time:.2f}秒") except KeyboardInterrupt: log_warning("\n拷贝操作已被用户中断") except Exception as e: log_failure(f"\n拷贝过程中发生异常: {str(e)}") def check_data_integrity(self): """检查数据完整性,运行check_market_data.py程序""" import subprocess import os try: log_process("开始检查数据完整性...") # 构建check_market_data.py的完整路径 script_path = os.path.join(os.path.dirname(__file__), 'check_market_data.py') # 运行check_market_data.py程序,直接显示输出 result = subprocess.run( [sys.executable, script_path], stdout=None, # 不捕获标准输出,直接显示在控制台上 stderr=None, # 不捕获错误输出,直接显示在控制台上 encoding='utf-8' ) if result.returncode == 0: log_success(f"\n数据完整性检查完成!") else: log_failure(f"\n数据完整性检查失败,返回码: {result.returncode}") except FileNotFoundError: log_failure(f"\n未找到check_market_data.py文件,请检查路径: {script_path}") except Exception as e: log_failure(f"\n执行数据完整性检查时发生异常: {e}") def process_stock_codes(self): """处理所有股票代码 - 优化版本""" try: with open(Config.INPUT_FILE, 'r', encoding='utf-8') as f: stock_codes = [line.strip().split('\t')[0] for line in f if line.strip()] except Exception as e: log_failure(f"读取股票代码文件失败: {e}") return total = len(stock_codes) completed = 0 start_time = time.time() def print_progress(): progress = (completed / total) * 100 elapsed = time.time() - start_time # 单行进度显示 log_process( f"\r进度: [{'#' * int(progress / 2)}{' ' * (50 - int(progress / 2))}] {progress:.1f}% | 已完成: {completed}/{total} | 耗时: {elapsed:.1f}s", end='', flush=True) # 按批次处理,避免一次性创建过多线程 batch_size = Config.MAX_THREADS * 2 for i in range(0, total, batch_size): batch = stock_codes[i:i + batch_size] with ThreadPoolExecutor(max_workers=Config.MAX_THREADS) as executor: futures = { executor.submit(self.downloader.process_stock_code, code, None): code for code in batch } for future in as_completed(futures): completed += 1 if completed % 5 == 0 or completed == total: print_progress() future.result() log_success(f"\r进度: [{('#' * 50)}] 100.0% | 已完成: {total}/{total} | 总耗时: {time.time() - start_time:.1f}s") log_success(f"\n个股数据下载完成!") # 创建全局AccountManager实例 _global_account_manager = AccountManager() @lru_cache(maxsize=1000) def get_cached_data(code): """缓存股票数据查询结果""" pro = _global_account_manager.get_next_account() # 使用全局AccountManager实例 return pro.daily(ts_code=code) def parse_args(): """解析命令行参数""" parser = argparse.ArgumentParser(description='Tushare数据更新工具') parser.add_argument('-c', '--choice', type=str, choices=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], help='要执行的操作: 0=退出, 1=更新股票代码表, 2=更新指数数据, 3=更新个股数据, 4=全部更新, 5=同步数据库, 6=全部工作, 7=拷贝数据到目标目录, 8=检查数据完整性, 9=获取个股基础信息') return parser.parse_args() if __name__ == "__main__": args = parse_args() downloader = ConsoleDataDownloader() if args.choice: # 使用命令行参数直接执行操作 downloader.run_with_choice(args.choice) else: # 进入交互式模式 downloader.run()