# update_day/update_tushare_totxt.py
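"""Tushare日线数据更新工具(根据下文代码整理的概要说明):

- 从Tushare Pro获取股票代码表、指数日线与个股日线数据,按股票保存为txt文件;
- 支持增量更新、多账户轮询、请求频率控制,并可将数据同步到MySQL数据库;
- 可通过命令行参数 -c/--choice 直接执行指定操作,或进入交互式菜单。
"""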
from threading import Thread
import os
import time
import logging
import threading
import socket
import sys
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd
import tushare as ts
from log_style import setup_logger, get_logger
from functools import lru_cache
import argparse
# 配置日志记录器
logger = setup_logger(
name='update_tushare_totxt',
log_file='app.log',
level=logging.INFO,
log_format='%(asctime)s - %(levelname)s - %(message)s',
console=True
)
# 定义带颜色的日志函数
def log_success(message, **kwargs):
    """成功信息,绿色显示"""
    logger.info(f"\033[32m{message}\033[0m", **kwargs)
def log_warning(message, **kwargs):
    """警告信息,黄色显示"""
    logger.warning(f"\033[33m{message}\033[0m", **kwargs)
def log_failure(message, **kwargs):
    """失败信息,红色显示"""
    logger.error(f"\033[31m{message}\033[0m", **kwargs)
def log_process(message, **kwargs):
"""处理中信息,黄色显示"""
logger.info(f"\033[33m{message}\033[0m", **kwargs)
def log_result(message, **kwargs):
"""结果信息,蓝色显示"""
logger.info(f"\033[34m{message}\033[0m", **kwargs)
def log_message(message, **kwargs):
"""普通信息,白色显示"""
logger.info(f"\033[37m{message}\033[0m", **kwargs)
# 数据库相关导入
import pymysql
from sqlalchemy import create_engine, text
class Config:
"""全局配置类"""
BASE_DIR = 'D:\\gp_data'
INPUT_FILE = os.path.join(BASE_DIR, 'code', 'all_stock_codes.txt')
OUTPUT_DIR = os.path.join(BASE_DIR, 'day')
INDEX_DIR = os.path.join(BASE_DIR, 'index')
    # 最大线程数,可通过环境变量MAX_THREADS设置,默认20
    MAX_THREADS = int(os.getenv('MAX_THREADS', 20))
    # 每分钟最大请求次数限制,可通过环境变量REQUEST_LIMIT设置,默认500
    REQUEST_LIMIT = int(os.getenv('REQUEST_LIMIT', 500))
# 请求失败时的最大重试次数
MAX_RETRIES = 3
# 时间常量表示60秒(1分钟),用于请求频率控制
ONE_MINUTE = 60
# 多账户Token数组从环境变量读取如果没有则使用默认值
ACCOUNTS = [
os.getenv('TUSHARE_TOKEN1', '9343e641869058684afeadfcfe7fd6684160852e52e85332a7734c8d'), # 主账户
os.getenv('TUSHARE_TOKEN2', 'b156282fb2328afa056346736a93778861ac6be97a485f4d1f55c493') # 备用账户1
]
# 新增数据库重试配置
DB_MAX_RETRIES = 3 # 最大重试次数
DB_RETRY_INTERVAL = 3 # 重试间隔(秒)
# 新增数据库配置,从环境变量读取敏感信息
DB_CONFIG = {
'host': os.getenv('DB_HOST', '127.0.0.1'),
'port': int(os.getenv('DB_PORT', 3306)),
'user': os.getenv('DB_USER', 'root'),
'password': os.getenv('DB_PASSWORD', 'stella0850'),
'database': os.getenv('DB_NAME', 'stock_data')
}
# 数据库连接池配置
DB_POOL_SIZE = 20 # 连接池大小
DB_MAX_OVERFLOW = 10 # 最大溢出连接数
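# 配置覆盖示例(仅为示意,变量名以上方Config为准;以Windows PowerShell为例):
#   $env:MAX_THREADS = "10"
#   $env:REQUEST_LIMIT = "200"
#   $env:TUSHARE_TOKEN1 = "<你的Tushare token>"
#   $env:DB_PASSWORD = "<你的数据库密码>"
#   python update_day/update_tushare_totxt.py -c 4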
class AccountManager:
# 类变量缓存,实现跨实例的缓存共享
_stock_basic_cache = None
_stock_basic_cache_time = 0
def __init__(self):
"""初始化账户管理器"""
self.accounts = [ts.pro_api(token) for token in Config.ACCOUNTS]
self.current_index = 0
        # 使用可重入锁:get_next_account在异常分支会递归调用自身,普通Lock会导致死锁
        self.lock = threading.RLock()
# 为每个账户单独维护请求时间,提高并发效率
self.last_request_time_per_account = [time.time() for _ in self.accounts]
def get_stock_basic(self, force_primary=False):
"""获取股票基本信息,带缓存机制"""
current_time = time.time()
# 使用类变量缓存,实现跨实例的缓存共享
# 缓存有效期为30分钟
if (AccountManager._stock_basic_cache is not None and
current_time - AccountManager._stock_basic_cache_time < 30 * 60):
return AccountManager._stock_basic_cache
# 缓存失效,重新获取
pro = self.get_next_account(force_primary)
df_sse = pro.stock_basic(exchange='SSE', list_status='L')
df_szse = pro.stock_basic(exchange='SZSE', list_status='L')
AccountManager._stock_basic_cache = pd.concat([df_sse, df_szse], ignore_index=True)
AccountManager._stock_basic_cache_time = current_time
return AccountManager._stock_basic_cache
def get_next_account(self, force_primary=False):
"""轮询获取下一个账户
force_primary: 强制使用主账户
"""
with self.lock:
try:
if force_primary:
return self.accounts[0] # 强制返回主账户
account = self.accounts[self.current_index]
self.current_index = (self.current_index + 1) % len(self.accounts)
return account
except (ConnectionError, TimeoutError) as e:
# 网络连接相关异常
log_failure(f"账户网络连接异常: {str(e)}")
except Exception as e:
# 其他异常
log_failure(f"获取账户异常: {str(e)}")
# 无论什么异常,都尝试下一个账户
self.current_index = (self.current_index + 1) % len(self.accounts)
# 防止无限递归,设置最大尝试次数
retry_count = getattr(self, '_retry_count', 0)
if retry_count >= len(self.accounts) * 2: # 尝试每个账户2次
setattr(self, '_retry_count', 0)
log_failure("所有账户都尝试失败,返回最后一个账户作为尝试")
return self.accounts[0] # 返回第一个账户作为最后的尝试
setattr(self, '_retry_count', retry_count + 1)
return self.get_next_account(force_primary)
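# 设计说明:Tushare按token限制请求频率,这里通过多账户轮询分摊请求;
# 实际的请求间隔控制在下方 DataDownloader.fetch_data_with_retry 中按账户总数计算。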
class DataDownloader:
"""数据下载器核心类"""
def __init__(self):
self.account_manager = AccountManager()
self.setup_logging()
self.create_dirs()
self.last_request_time = time.time()
self.db_conn = None # 新增数据库连接属性
self.total_files = 0
self.processed_files = 0
self.progress = {'current': 0, 'total': 100, 'message': '准备开始'}
# 数据库连接池管理
self.conn_pool = {}
self.conn_count = 0
self.max_connections = 5 # 最大连接数限制
self.active_connections = set()
def show_progress(self, current, total, start_time):
"""进度显示"""
self.progress = {
'current': current,
'total': total,
'message': f'处理中: {current / total * 100:.1f}% | 耗时: {time.time() - start_time:.1f}s'
}
progress = current / total * 100
elapsed = time.time() - start_time
log_process(f"\r处理中: {progress:.1f}% ({current}/{total}) | 耗时: {elapsed:.1f}s", end='', flush=True)
if current == total:
log_success(f"\n数据处理完成!") # 添加换行符和完成信息
# def show_progress(self, current, total, start_time):
# """更频繁的进度显示"""
# if current == total or current % 2 == 0:
# progress = current / total * 100
# elapsed = time.time() - start_time
# print(f"\r处理中: {progress:.1f}% ({current}/{total}) | 耗时: {elapsed:.1f}s", end='')
# if current == total:
# print(" 完成!")
def setup_logging(self):
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('app.log')
]
)
def connect_db(self):
"""连接数据库,带重试机制和自动创建数据库功能"""
# 检查连接池是否有可用连接
if self.conn_pool:
# 使用第一个可用连接
conn_id, conn = next(iter(self.conn_pool.items()))
try:
# 测试连接是否有效
cursor = conn.cursor()
cursor.execute("SELECT 1")
cursor.close()
del self.conn_pool[conn_id]
self.active_connections.add(conn_id)
log_message(f"复用数据库连接 {conn_id}")
self.db_conn = conn
return True
except Exception:
# 连接无效,移除并继续创建新连接
try:
conn.close()
except:
pass
del self.conn_pool[conn_id]
log_warning(f"移除无效连接 {conn_id}")
        # 连接池中没有可复用连接:如果连接数超过限制,等待一段时间
        while self.conn_count >= self.max_connections:
            log_warning(f"数据库连接数达到上限 {self.max_connections},等待释放...")
            time.sleep(1)
        for attempt in range(Config.DB_MAX_RETRIES):
            try:
                # 先尝试连接指定数据库
                conn = pymysql.connect(**Config.DB_CONFIG)
                conn_id = f"conn_{int(time.time() * 1000)}"
                self.conn_count += 1
                self.active_connections.add(conn_id)
                log_message(f"创建数据库连接 {conn_id},当前连接数: {self.conn_count}")
                self.db_conn = conn
                return True  # 静默返回,不记录成功日志
            except pymysql.err.OperationalError as e:
                if e.args[0] == 1049:  # 数据库不存在错误码
                    try:
                        # 连接MySQL服务器(不带数据库名)
                        temp_config = Config.DB_CONFIG.copy()
                        temp_config.pop('database')
                        temp_conn = pymysql.connect(**temp_config)
                        # 创建数据库(使用Config.DB_CONFIG中的database名)
                        with temp_conn.cursor() as cursor:
                            cursor.execute(f"CREATE DATABASE IF NOT EXISTS {Config.DB_CONFIG['database']} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
                        temp_conn.commit()
                        temp_conn.close()
                        log_message(f"数据库 {Config.DB_CONFIG['database']} 创建成功")
                        # 重新尝试连接
                        continue
                    except Exception as create_e:
                        log_failure(f"创建数据库失败: {create_e}")
                        return False
                elif attempt < Config.DB_MAX_RETRIES - 1:  # 不是最后一次尝试
                    log_warning(f"数据库连接失败(尝试 {attempt + 1}/{Config.DB_MAX_RETRIES}): {e}")
                    time.sleep(Config.DB_RETRY_INTERVAL)
                else:
                    log_failure(f"数据库连接失败,已达到最大重试次数 {Config.DB_MAX_RETRIES}")
                    return False
            except Exception as e:
                if attempt < Config.DB_MAX_RETRIES - 1:  # 不是最后一次尝试
                    log_warning(f"数据库连接失败(尝试 {attempt + 1}/{Config.DB_MAX_RETRIES}): {e}")
                    time.sleep(Config.DB_RETRY_INTERVAL)
                else:
                    log_failure(f"数据库连接失败,已达到最大重试次数 {Config.DB_MAX_RETRIES}")
                    return False
        return False
    def return_connection(self, conn=None):
        """返回连接到连接池或关闭连接"""
        # 如果提供了连接,使用提供的连接
        if conn:
            try:
                # 测试连接是否仍然有效
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
                cursor.close()
                # 生成唯一连接ID
                conn_id = f"conn_{int(time.time() * 1000)}"
                # 从活动连接中移除并添加到连接池
                if hasattr(self, 'active_connections'):
                    for active_conn_id in list(self.active_connections):
                        if active_conn_id.endswith(str(hash(str(conn)))):
                            self.active_connections.remove(active_conn_id)
                # 添加到连接池
                self.conn_pool[conn_id] = conn
                log_message(f"连接已返回连接池: {conn_id}")
                return True
            except Exception:
                # 连接已无效,关闭它
                try:
                    conn.close()
                    if hasattr(self, 'conn_count'):
                        self.conn_count = max(0, self.conn_count - 1)
                        log_message(f"无效连接已关闭,当前连接数: {self.conn_count}")
                except:
                    pass
                return False
        # 如果使用的是对象的默认连接
        elif hasattr(self, 'db_conn') and self.db_conn:
            try:
                # 测试连接是否仍然有效
                cursor = self.db_conn.cursor()
                cursor.execute("SELECT 1")
                cursor.close()
                # 生成唯一连接ID
                conn_id = f"conn_{int(time.time() * 1000)}"
                # 从活动连接中移除并添加到连接池
                for active_conn_id in list(self.active_connections):
                    self.active_connections.remove(active_conn_id)
                # 添加到连接池
                self.conn_pool[conn_id] = self.db_conn
                self.db_conn = None  # 清除对象属性
                log_message(f"默认连接已返回连接池: {conn_id}")
                return True
            except Exception:
                # 连接已无效,关闭它
                try:
                    self.db_conn.close()
                    self.db_conn = None
                    self.conn_count = max(0, self.conn_count - 1)
                    log_message(f"默认无效连接已关闭,当前连接数: {self.conn_count}")
                except:
                    pass
                return False
        return False
def update_to_db(self, data: pd.DataFrame, table_name: str):
"""线程安全的数据库更新方法"""
engine = None
try:
# 每个线程创建自己的引擎和连接,使用更好的连接池配置
engine = create_engine(
f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@"
f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4",
pool_size=Config.MAX_THREADS, # 设置连接池大小
max_overflow=Config.DB_MAX_OVERFLOW,
pool_timeout=30, # 连接超时时间
pool_recycle=3600, # 连接回收时间
pool_pre_ping=True, # 连接前检查有效性
echo=False,
connect_args={
"charset": "utf8mb4",
"autocommit": True, # 自动提交
"connect_timeout": 30,
"read_timeout": 30,
"write_timeout": 30
}
)
# 强制表名小写并移除特殊字符
table_name = table_name.lower().replace('.', '_')
# 添加表结构检查
if table_name.startswith('stock_'):
if 'trade_date' not in data.columns:
log_failure(f"数据缺少必要列 trade_date")
return False
if not pd.api.types.is_datetime64_any_dtype(data['trade_date']):
data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d')
with engine.begin() as connection: # 使用事务
data.to_sql(table_name, connection, if_exists='append', index=False, chunksize=1000)
return True
except Exception as e:
log_failure(f"数据库更新失败: {e}")
return False
finally:
# 确保资源释放
if engine:
try:
engine.dispose()
log_message(f"SQLAlchemy引擎已释放")
except Exception as e:
log_failure(f"释放SQLAlchemy引擎失败: {e}")
def update_database(self, progress_queue=None, completion_label=None):
try:
log_process("开始更新数据库...")
start_time = time.time()
# 1. 更新股票代码表(使用缓存机制)
log_process("正在更新股票代码表...")
df_all = self.account_manager.get_stock_basic(force_primary=True)
self.update_to_db(df_all, 'stock_basic')
# 2. 从本地文件更新日线数据
log_process("正在更新个股日线数据...")
files = [f for f in os.listdir(Config.OUTPUT_DIR) if f.endswith('_daily_data.txt')]
self.total_files = len(files)
self.processed_files = 0
if self.total_files > 0:
# 使用批量插入优化数据库写入
self.batch_update_stock_data(files, start_time)
else:
log_message("没有找到需要处理的个股数据文件")
# 初始显示进度0%
# self.show_progress(0, self.total_files, start_time)
#
# # 使用批量插入优化数据库写入
# self.batch_update_stock_data(files, start_time)
# for i, file in enumerate(files, 1):
# file_path = os.path.join(Config.OUTPUT_DIR, file)
# df = self.read_from_txt(file_path)
# if df is not None:
# code = file.split('_')[0].lower().replace('.', '_')
# self.update_to_db(df, f'stock_{code}')
#
# self.processed_files = i
# self.show_progress(i, self.total_files, start_time)
# 3. 更新指数数据(从本地文件)
log_process("正在更新指数数据...")
index_file = os.path.join(Config.INDEX_DIR, '000001.SH.txt')
if os.path.exists(index_file):
df_index = self.read_from_txt(index_file)
if df_index is not None:
self.update_to_db(df_index, 'index_daily')
            # 4. 更新个股数据(分批处理)
            log_process("正在更新实时个股数据...")
            # 获取一个Tushare API实例用于批量请求
            pro = self.account_manager.get_next_account()
            stock_codes = df_all['ts_code'].tolist()
            batch_size = 400
            for i in range(0, len(stock_codes), batch_size):
                batch = stock_codes[i:i + batch_size]
                df = self.fetch_data_with_retry(pro.daily, ts_code=','.join(batch))
if df is not None:
self.update_to_db(df, 'stock_daily')
log_success(f"数据库更新完成! 总耗时: {time.time() - start_time:.2f}")
if completion_label:
completion_label.config(text="数据库更新完成!", foreground="green")
if progress_queue:
progress_queue.put(1)
return True
except Exception as e:
log_failure(f"数据库更新失败: {e}")
if completion_label:
completion_label.config(text="数据库更新失败!", foreground="red")
return False
def batch_update_stock_data(self, files, start_time):
"""批量更新股票数据到数据库以提高性能 - 优化版本"""
engine = None
try:
# 创建一个共享的数据库引擎,使用更好的连接池配置
engine = create_engine(
f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@"
f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4",
pool_size=Config.MAX_THREADS,
max_overflow=Config.DB_MAX_OVERFLOW,
pool_timeout=30,
pool_recycle=3600,
pool_pre_ping=True,
echo=False # 关闭SQL日志以提高性能
)
# 批量处理文件每批处理100个文件
batch_size = 100
processed_count = 0
total_files = len(files)
for i in range(0, len(files), batch_size):
batch_files = files[i:i + batch_size]
all_data = []
# 收集一批数据
for file in batch_files:
file_path = os.path.join(Config.OUTPUT_DIR, file)
df = self.read_from_txt(file_path)
if df is not None:
code = file.split('_')[0].lower().replace('.', '_')
df['stock_code'] = code # 添加股票代码标识
all_data.append(df)
# 更新单个文件处理进度
processed_count += 1
self.show_progress(processed_count, total_files, start_time)
# 合并数据并批量插入
if all_data:
combined_df = pd.concat(all_data, ignore_index=True)
# 使用事务批量插入
with engine.begin() as connection:
# 先删除重复数据(如果有需要)
# 然后批量插入新数据
combined_df.to_sql('stock_data_combined', connection, if_exists='append', index=False,
chunksize=2000, method='multi')
# 确保最终进度显示为100%
if total_files > 0:
self.show_progress(total_files, total_files, start_time)
except Exception as e:
log_failure(f"批量更新股票数据失败: {e}")
finally:
# 确保资源释放
if engine:
try:
engine.dispose()
log_message(f"批量更新的SQLAlchemy引擎已释放")
except Exception as e:
log_failure(f"释放批量更新的SQLAlchemy引擎失败: {e}")
def create_dirs(self):
# 创建基础目录
os.makedirs(Config.BASE_DIR, exist_ok=True)
# 创建代码文件目录
os.makedirs(os.path.dirname(Config.INPUT_FILE), exist_ok=True)
# 创建输出目录
os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
# 创建指数目录
os.makedirs(Config.INDEX_DIR, exist_ok=True)
# 优化 fetch_data_with_retry 方法,改进请求频率控制
def fetch_data_with_retry(self, func, *args, **kwargs):
"""带重试机制的数据获取方法 - 优化版本"""
# 计算请求频率限制(总请求数/分钟)
total_requests_per_minute = Config.REQUEST_LIMIT * len(Config.ACCOUNTS)
requests_per_second = total_requests_per_minute / 60.0
min_interval = 1.0 / requests_per_second
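        # 举例说明(按当前默认配置推算):REQUEST_LIMIT=500、2个账户时,
        # 总限额为1000次/分钟,约16.7次/秒,因此min_interval约为0.06秒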
for attempt in range(Config.MAX_RETRIES):
try:
# 精确控制请求间隔
current_time = time.time()
time_since_last_request = current_time - self.last_request_time
# 确保请求间隔不小于最小限制
if time_since_last_request < min_interval:
time.sleep(min_interval - time_since_last_request)
result = func(*args, **kwargs)
self.last_request_time = time.time()
return result if result is not None and not (hasattr(result, 'empty') and result.empty) else None
            except Exception as e:
                wait_time = min(2 ** attempt, 5)
                log_warning(f"请求失败(第{attempt + 1}/{Config.MAX_RETRIES}次): {e},{wait_time}秒后重试")
                time.sleep(wait_time)
        # 重试次数用尽仍未成功,返回None
        return None
def save_to_txt(self, data: pd.DataFrame, filename: str) -> bool:
try:
# 按交易日期降序排序,确保最新交易日排在最前面
if 'trade_date' in data.columns:
data = data.sort_values('trade_date', ascending=False)
data.to_csv(filename, index=False, sep='\t', encoding='utf-8')
# logging.info(f"数据已保存到 {filename}")
return True
except Exception as e:
log_failure(f"保存文件时出错: {e}")
return False
def process_stock_code(self, code, progress_queue=None): # 修改参数默认值为None
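        """处理单只股票代码:若本地已有数据文件,则只获取最新交易日之后的增量数据并合并去重;
        否则下载全部历史日线数据。结果按trade_date降序保存为txt文件。"""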
pro = self.account_manager.get_next_account()
try:
output_file = os.path.join(Config.OUTPUT_DIR, f"{code}_daily_data.txt")
# 检查是否存在现有数据文件
if os.path.exists(output_file):
# 读取现有数据,获取最新的交易日期
existing_df = self.read_from_txt(output_file)
if existing_df is not None and not existing_df.empty:
# 获取最新交易日期
if 'trade_date' in existing_df.columns:
# 由于read_from_txt会将trade_date转换为datetime格式
# 确保现有数据的trade_date列是datetime格式
if not pd.api.types.is_datetime64_any_dtype(existing_df['trade_date']):
existing_df['trade_date'] = pd.to_datetime(existing_df['trade_date'], format='%Y%m%d')
# 获取最新交易日期
latest_date_dt = existing_df['trade_date'].max()
# 计算下一个交易日的起始日期(避免重复获取同一天数据)
next_date_dt = latest_date_dt + timedelta(days=1)
next_date = next_date_dt.strftime('%Y%m%d')
# 获取最新日期之后的数据
df = self.fetch_data_with_retry(pro.daily, ts_code=code, start_date=next_date)
if df is not None and not df.empty:
# 将新数据的trade_date列转换为datetime格式以便合并
df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')
# 合并现有数据和新数据
combined_df = pd.concat([existing_df, df], ignore_index=True)
# 去重,避免重复数据
combined_df = combined_df.drop_duplicates(subset=['trade_date', 'ts_code'], keep='last')
# 按交易日期降序排序,最新交易日排在最前面
combined_df = combined_df.sort_values('trade_date', ascending=False)
# 将trade_date转换回字符串格式保存
combined_df['trade_date'] = combined_df['trade_date'].dt.strftime('%Y%m%d')
# 保存合并后的数据
self.save_to_txt(combined_df, output_file)
else:
# 如果现有数据没有 trade_date 列,重新获取全部数据
df = self.fetch_data_with_retry(pro.daily, ts_code=code)
if df is not None:
self.save_to_txt(df, output_file)
else:
# 现有数据为空,重新获取全部数据
df = self.fetch_data_with_retry(pro.daily, ts_code=code)
if df is not None:
self.save_to_txt(df, output_file)
else:
# 文件不存在,获取全部数据
df = self.fetch_data_with_retry(pro.daily, ts_code=code)
if df is not None:
self.save_to_txt(df, output_file)
if progress_queue is not None: # 添加判断
progress_queue.put(1)
except (ConnectionError, TimeoutError) as e:
log_failure(f"股票 {code} 网络连接异常: {str(e)}")
except pd.errors.EmptyDataError:
log_warning(f"股票 {code} 返回空数据")
except (pymysql.err.OperationalError, pymysql.err.InterfaceError) as e:
log_failure(f"股票 {code} 数据库操作异常: {str(e)}")
except Exception as e:
log_failure(f"处理股票代码 {code} 时出错: {str(e)}")
finally:
if progress_queue is not None: # 添加判断
progress_queue.put(1)
def fetch_and_save_index_data(self, ts_code='000001.SH', progress_queue=None, completion_label=None):
pro = self.account_manager.get_next_account(force_primary=True) # 强制使用主账户
try:
df = pro.index_daily(ts_code=ts_code)
if df is None or df.empty:
log_warning(f"未获取到指数 {ts_code} 的日线数据")
return False
os.makedirs(Config.INDEX_DIR, exist_ok=True)
output_file = os.path.join(Config.INDEX_DIR, f"{ts_code}.txt")
self.save_to_txt(df, output_file)
if completion_label:
completion_label.config(text="指数数据下载完成!", foreground="green")
if progress_queue:
progress_queue.put(1)
return True
except Exception as e:
log_failure(f"Error fetching index data for {ts_code}: {e}")
if completion_label:
completion_label.config(text="获取指数数据失败!", foreground="red")
return False
def get_all_stock_codes(self, progress_queue=None, completion_label=None):
try:
# 使用缓存的股票基本信息
df_all = self.account_manager.get_stock_basic(force_primary=True)
df_all[['ts_code', 'name']].to_csv(Config.INPUT_FILE, index=False, header=False, sep='\t')
log_success(f'所有资料已保存到 {Config.INPUT_FILE}')
if completion_label:
completion_label.config(text="代码表更新完成!", foreground="green")
if progress_queue:
progress_queue.put(1)
except Exception as e:
logging.error(f'发生错误: {e}')
if completion_label:
completion_label.config(text="更新代码表失败!", foreground="red")
def fetch_and_save_stock_basic_info(self):
"""获取个股基础信息并保存到baseinfo目录"""
try:
log_process("开始获取个股基础信息...")
# 获取Tushare API实例
pro = self.account_manager.get_next_account(force_primary=True) # 强制使用主账户
# 获取股票基础信息
df = pro.stock_basic(
exchange='', # 交易所代码,不指定则查询所有
list_status='L', # 上市状态L上市D退市P暂停上市
fields='ts_code,symbol,name,area,industry,fullname,enname,market,exchange,curr_type,list_status,list_date,delist_date,is_hs'
)
if df is None or df.empty:
log_failure("未获取到个股基础信息")
return False
# 创建baseinfo目录
baseinfo_dir = os.path.join(Config.BASE_DIR, 'baseinfo')
os.makedirs(baseinfo_dir, exist_ok=True)
# 保存数据到文件使用TXT格式而不是CSV
output_file = os.path.join(baseinfo_dir, 'stock_basic_info.txt')
self.save_to_txt(df, output_file)
log_success(f"个股基础信息已成功保存到: {output_file}")
log_result(f"共获取到 {len(df)} 条个股基础信息")
return True
except Exception as e:
log_failure(f"获取个股基础信息失败: {str(e)}")
log_failure(f"获取个股基础信息失败: {str(e)}")
return False
def process_stock_codes_batch(self, codes: list):
"""批量处理股票代码(每次最多4000个)"""
pro = self.account_manager.get_next_account()
processed_count = 0
failed_count = 0
try:
# 分批处理每批400个考虑单次返回限制
batch_size = 400
total_batches = (len(codes) + batch_size - 1) // batch_size
for batch_idx in range(0, len(codes), batch_size):
batch = codes[batch_idx:batch_idx + batch_size]
try:
# 获取批量数据
df = self.fetch_data_with_retry(pro.daily, ts_code=','.join(batch))
if df is not None and not df.empty:
# 按股票代码拆分数据并保存
for code in batch:
try:
# 筛选当前股票的数据
code_df = df[df['ts_code'] == code]
if not code_df.empty:
# 保存到文件
output_file = os.path.join(Config.OUTPUT_DIR, f"{code}_daily_data.txt")
if self.save_to_txt(code_df, output_file):
# 更新到数据库
if self.update_to_db(code_df, 'stock_daily'):
processed_count += 1
else:
failed_count += 1
log_warning(f"数据库更新失败,但文件已保存: {code}")
else:
failed_count += 1
log_failure(f"文件保存失败: {code}")
else:
log_warning(f"未获取到股票 {code} 的数据")
failed_count += 1
except Exception as code_e:
failed_count += 1
log_failure(f"处理股票 {code} 时出错: {str(code_e)}")
else:
failed_count += len(batch)
log_warning(f"批次 {batch_idx//batch_size + 1}/{total_batches} 返回空数据")
# 显示批次进度
log_result(f"批次 {batch_idx//batch_size + 1}/{total_batches} 处理完成 | 成功: {processed_count} | 失败: {failed_count}")
except (ConnectionError, TimeoutError) as e:
failed_count += len(batch)
log_failure(f"批次 {batch_idx//batch_size + 1} 网络异常: {str(e)}")
# 更换账户后重试
pro = self.account_manager.get_next_account()
except Exception as batch_e:
failed_count += len(batch)
log_failure(f"批次 {batch_idx//batch_size + 1} 处理异常: {str(batch_e)}")
# 记录总体处理结果
log_result(f"批量处理完成 | 总成功: {processed_count} | 总失败: {failed_count}")
return {'success': processed_count, 'failed': failed_count}
except Exception as e:
log_failure(f"批量处理主流程失败: {str(e)}")
return {'success': processed_count, 'failed': failed_count + (len(codes) - processed_count - failed_count)}
finally:
# 确保所有连接正确管理
# 添加返回连接方法的调用
try:
if hasattr(self, 'db_conn') and self.db_conn:
self.db_conn.close()
self.db_conn = None
log_message("数据库连接已关闭")
except Exception as e:
log_failure(f"关闭数据库连接时出错: {e}")
def read_from_txt(self, file_path):
"""从TXT文件读取行情数据 - 内存优化版本"""
try:
# 使用chunksize分块读取大文件
chunks = []
for chunk in pd.read_csv(file_path, sep='\t', encoding='utf-8', chunksize=10000):
# 对每一块进行预处理
chunk.columns = chunk.columns.str.lower()
if 'trade_date' in chunk.columns:
chunk['trade_date'] = pd.to_datetime(chunk['trade_date'], format='%Y%m%d')
chunks.append(chunk)
if chunks:
df = pd.concat(chunks, ignore_index=True)
return df
else:
return pd.DataFrame()
except Exception as e:
log_failure(f"读取文件 {file_path} 失败: {e}")
return None
# 添加新的方法用于优化数据库存储
def optimize_database_storage(self):
"""优化数据库存储结构"""
engine = None
try:
engine = create_engine(
f"mysql+pymysql://{Config.DB_CONFIG['user']}:{Config.DB_CONFIG['password']}@"
f"{Config.DB_CONFIG['host']}:{Config.DB_CONFIG['port']}/{Config.DB_CONFIG['database']}?charset=utf8mb4",
pool_pre_ping=True
)
with engine.begin() as connection:
# 添加索引以提高查询速度
connection.execute(text("ALTER TABLE stock_data_combined ADD INDEX IF NOT EXISTS idx_trade_date (trade_date)"))
connection.execute(text("ALTER TABLE stock_data_combined ADD INDEX IF NOT EXISTS idx_stock_code (stock_code)"))
# 如果有重复数据,添加去重逻辑
connection.execute(text("""
DELETE s1 FROM stock_data_combined s1
INNER JOIN stock_data_combined s2
WHERE s1.id > s2.id AND s1.trade_date = s2.trade_date AND s1.stock_code = s2.stock_code
"""))
log_message("数据库存储优化完成")
return True
except Exception as e:
log_failure(f"数据库存储优化失败: {e}")
return False
finally:
if engine:
try:
engine.dispose()
log_message(f"SQLAlchemy引擎已释放")
except Exception as e:
log_failure(f"释放SQLAlchemy引擎失败: {e}")
class ConsoleDataDownloader:
"""控制台模式数据下载器"""
def __init__(self):
self.downloader = DataDownloader()
def run_with_choice(self, choice):
"""使用指定的选项运行"""
try:
if choice == '0':
log_message("程序正在退出...")
sys.exit(0)
if choice == '6': # 新增全部工作逻辑
log_process("开始执行全部工作...")
start_time = time.time()
# 1. 更新股票代码
self.downloader.get_all_stock_codes()
# 2. 更新指数数据
self.downloader.fetch_and_save_index_data('000001.SH')
# 3. 更新个股数据
self.process_stock_codes()
# 4. 同步数据库
if not self.downloader.connect_db():
log_failure("数据库连接失败,跳过数据库同步")
else:
self.downloader.update_database()
log_success(f"全部工作完成! 总耗时: {time.time() - start_time:.2f}")
return
# 检查数据库连接是否正常(仅当选择数据库相关操作时)
if choice == '5' and not self.downloader.connect_db():
log_failure("数据库连接失败,请检查配置后重试")
return
# 处理选项7拷贝数据
if choice == '7':
self.copy_data_to_target_directory()
return
# 处理选项8检查数据完整性
if choice == '8':
self.check_data_integrity()
return
# 处理选项9获取个股基础信息
if choice == '9':
self.downloader.fetch_and_save_stock_basic_info()
return
tasks = []
if choice in ('1', '4'):
tasks.append(self.downloader.get_all_stock_codes)
if choice in ('2', '4'):
tasks.append(lambda: self.downloader.fetch_and_save_index_data('000001.SH'))
if choice in ('3', '4'):
tasks.append(self.process_stock_codes)
if choice == '5': # 新增数据库更新任务
tasks.append(self.downloader.update_database)
if not tasks:
log_warning("无效选项")
return
log_process("开始执行任务...")
start_time = time.time()
for task in tasks:
task()
log_success(f"任务完成! 总耗时: {time.time() - start_time:.2f}")
except Exception as e:
log_failure(f"执行任务时发生异常: {e}")
log_failure(f"执行任务时发生异常: {e}")
def run(self):
try:
print("请选择要执行的操作:")
print("1. 更新股票代码表")
print("2. 更新指数数据")
print("3. 更新个股数据")
print("4. 全部更新")
print("5. 同步数据库") # 新增选项
print("6. 全部工作(下载+同步数据库)") # 新增选项
print("7. 拷贝更新后的数据到目标目录") # 新增选项
print("8. 检查数据完整性") # 新增选项
print("9. 获取个股基础信息") # 新增选项
print("0. 退出")
while True:
try:
choice = input("请输入选项: ").strip()
self.run_with_choice(choice)
                except KeyboardInterrupt:
                    print("\n程序已中断")
                    return
                except (ConnectionError, TimeoutError) as e:
                    log_failure(f"用户交互过程中发生网络异常: {str(e)}")
                except Exception as e:
                    log_failure(f"用户交互过程中发生异常: {str(e)}")
except KeyboardInterrupt:
log_warning("用户中断程序")
print("程序已中断")
except Exception as e:
log_failure(f"程序运行异常: {str(e)}")
    def copy_data_to_target_directory(self):
        """拷贝day目录(Config.BASE_DIR\\day)下的数据到目标目录 D:\\data\\jq_hc\\data"""
        import shutil
        # 使用Config.BASE_DIR作为基础路径保持代码一致性(os、time已在模块顶部导入)
        source_dir = os.path.join(Config.BASE_DIR, 'day')
        target_dir = 'D:\\data\\jq_hc\\data'
        log_process(f"准备从 {source_dir} 拷贝数据到 {target_dir}")
try:
# 确保目标目录存在
if not os.path.exists(target_dir):
os.makedirs(target_dir, exist_ok=True)
log_message(f"已创建目标目录: {target_dir}")
# 检查源目录是否存在
if not os.path.exists(source_dir):
log_failure(f"源目录 {source_dir} 不存在,请检查路径")
return
# 获取源目录中的txt文件列表
files = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f)) and f.lower().endswith('.txt')]
total_files = len(files)
if total_files == 0:
log_warning("源目录中没有找到文件")
return
log_process(f"开始拷贝 {total_files} 个文件...")
start_time = time.time()
copied_count = 0
failed_count = 0
for i, file_name in enumerate(files, 1):
source_path = os.path.join(source_dir, file_name)
target_path = os.path.join(target_dir, file_name)
try:
# 拷贝文件并覆盖已存在的文件
shutil.copy2(source_path, target_path)
copied_count += 1
                    # 显示进度(logging不支持print的end/flush参数,直接整行输出)
                    progress = (i / total_files) * 100
                    elapsed = time.time() - start_time
                    log_process(f"\r进度: [{('#' * int(progress / 2)):50}] {progress:.1f}% | 耗时: {elapsed:.1f}s | {file_name}")
except Exception as e:
failed_count += 1
log_failure(f"\n拷贝文件 {file_name} 失败: {str(e)}")
# 打印完成信息
log_success(f"\n拷贝完成!")
log_result(f"成功: {copied_count}, 失败: {failed_count}")
log_result(f"总耗时: {time.time() - start_time:.2f}")
except KeyboardInterrupt:
log_warning("\n拷贝操作已被用户中断")
except Exception as e:
log_failure(f"\n拷贝过程中发生异常: {str(e)}")
def check_data_integrity(self):
"""检查数据完整性运行check_market_data.py程序"""
import subprocess
import os
try:
log_process("开始检查数据完整性...")
# 构建check_market_data.py的完整路径
script_path = os.path.join(os.path.dirname(__file__), 'check_market_data.py')
# 运行check_market_data.py程序直接显示输出
result = subprocess.run(
[sys.executable, script_path],
stdout=None, # 不捕获标准输出,直接显示在控制台上
stderr=None, # 不捕获错误输出,直接显示在控制台上
encoding='utf-8'
)
if result.returncode == 0:
log_success(f"\n数据完整性检查完成!")
else:
log_failure(f"\n数据完整性检查失败,返回码: {result.returncode}")
except FileNotFoundError:
log_failure(f"\n未找到check_market_data.py文件请检查路径: {script_path}")
except Exception as e:
log_failure(f"\n执行数据完整性检查时发生异常: {e}")
def process_stock_codes(self):
"""处理所有股票代码 - 优化版本"""
try:
with open(Config.INPUT_FILE, 'r', encoding='utf-8') as f:
stock_codes = [line.strip().split('\t')[0] for line in f if line.strip()]
except Exception as e:
log_failure(f"读取股票代码文件失败: {e}")
return
total = len(stock_codes)
completed = 0
start_time = time.time()
def print_progress():
progress = (completed / total) * 100
elapsed = time.time() - start_time
# 单行进度显示
            # logging不支持print的end/flush参数,直接整行输出
            log_process(
                f"\r进度: [{'#' * int(progress / 2)}{' ' * (50 - int(progress / 2))}] {progress:.1f}% | 已完成: {completed}/{total} | 耗时: {elapsed:.1f}s")
# 按批次处理,避免一次性创建过多线程
batch_size = Config.MAX_THREADS * 2
for i in range(0, total, batch_size):
batch = stock_codes[i:i + batch_size]
with ThreadPoolExecutor(max_workers=Config.MAX_THREADS) as executor:
futures = {
executor.submit(self.downloader.process_stock_code, code, None): code
for code in batch
}
for future in as_completed(futures):
completed += 1
if completed % 5 == 0 or completed == total:
print_progress()
future.result()
log_success(f"\r进度: [{('#' * 50)}] 100.0% | 已完成: {total}/{total} | 总耗时: {time.time() - start_time:.1f}s")
log_success(f"\n个股数据下载完成!")
# 创建全局AccountManager实例
_global_account_manager = AccountManager()
@lru_cache(maxsize=1000)
def get_cached_data(code):
"""缓存股票数据查询结果"""
pro = _global_account_manager.get_next_account() # 使用全局AccountManager实例
return pro.daily(ts_code=code)
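# 说明:lru_cache会在进程生命周期内缓存get_cached_data的返回结果(DataFrame),不会自动过期;
# 若需要在同一进程内获取最新数据,可调用 get_cached_data.cache_clear() 手动清空缓存。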
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description='Tushare数据更新工具')
parser.add_argument('-c', '--choice', type=str, choices=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
help='要执行的操作: 0=退出, 1=更新股票代码表, 2=更新指数数据, 3=更新个股数据, 4=全部更新, 5=同步数据库, 6=全部工作, 7=拷贝数据到目标目录, 8=检查数据完整性, 9=获取个股基础信息')
return parser.parse_args()
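# 命令行用法示例(示意):
#   python update_day/update_tushare_totxt.py          # 不带参数,进入交互式菜单
#   python update_day/update_tushare_totxt.py -c 6     # 直接执行"全部工作(下载+同步数据库)"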
if __name__ == "__main__":
args = parse_args()
downloader = ConsoleDataDownloader()
if args.choice:
# 使用命令行参数直接执行操作
downloader.run_with_choice(args.choice)
else:
# 进入交互式模式
downloader.run()