import os
import sys
from numba import types  # numba type namespace (currently unused)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import logging
from numba import jit
from datetime import datetime

# ========== Environment configuration ==========
plt.rcParams['font.sans-serif'] = ['SimHei']  # font with CJK glyphs for the Chinese plot labels
plt.rcParams['axes.unicode_minus'] = False    # keep minus signs renderable with that font
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ========== Strategy parameters ==========
ATR_WINDOW = 5
VOLATILITY_WINDOW = 10
# Per-regime upper bound for the LLV/HHV ratio used by the volatility condition.
BULL_THRESHOLD = 0.83
BEAR_THRESHOLD = 0.77
NEUTRAL_THRESHOLD = 0.81
# Per-regime holding period (trading days) used by the backtest.
HOLDING_DAYS_MAP = {
    'bull': 4,
    'bear': 2,
    'neutral': 3
}
# Only data on or after this date is used, for both the index and the stocks.
START_DATE = pd.Timestamp('2022-01-01')


# ========== Core strategy logic ==========
@jit(nopython=True)
def calculate_technical_indicators(close, high, low, volume, atr_window, volatility_window):
    """Compute MACD, its signal line and ATR for one price series.

    Parameters
    ----------
    close, high, low : 1-D float arrays of equal length, oldest bar first.
    volume : accepted for interface compatibility; not used.
    atr_window : smoothing period of the ATR recursion (factor 1/atr_window).
    volatility_window : accepted for interface compatibility; not used.

    Returns
    -------
    (macd, signal, atr) : three arrays with the same length as ``close``.
    """
    n = len(close)
    macd = np.zeros(n)
    signal = np.zeros(n)
    atr = np.zeros(n)
    if n == 0:
        return macd, signal, atr

    # MACD: EMA12 - EMA26, signal line is an EMA of MACD with alpha 0.2.
    # Both EMAs are seeded with the first close. Seeding with 0 (the old
    # behaviour) shifts MACD by close[0] * ((25/27)**i - (11/13)**i), a bias
    # that lingers for many bars and distorts the momentum condition.
    ema12 = np.empty(n)
    ema26 = np.empty(n)
    ema12[0] = close[0]
    ema26[0] = close[0]
    for i in range(1, n):
        ema12[i] = ema12[i - 1] * 11 / 13 + close[i] * 2 / 13
        ema26[i] = ema26[i - 1] * 25 / 27 + close[i] * 2 / 27
        macd[i] = ema12[i] - ema26[i]
        signal[i] = signal[i - 1] * 0.8 + macd[i] * 0.2

    # ATR: recursive average of the true range. Seed with the first bar's
    # range instead of 0 so the early values are not artificially small.
    atr[0] = high[0] - low[0]
    for i in range(1, n):
        tr = max(high[i] - low[i],
                 abs(high[i] - close[i - 1]),
                 abs(low[i] - close[i - 1]))
        atr[i] = atr[i - 1] * (atr_window - 1) / atr_window + tr / atr_window

    return macd, signal, atr


@jit(nopython=True)
def _generate_trading_signals_core(close, open_, high, low, volume, macd, signal, atr,
                                   thresholds, volatility_window):
    """Numba kernel: evaluate all entry conditions bar by bar.

    ``thresholds`` holds one LLV/HHV threshold per bar (same length as close).
    Returns a boolean array; the first three bars are never signalled.
    """
    n = len(close)
    signals = np.zeros(n, dtype=np.bool_)

    for i in range(3, n):
        # Skip bars whose prices would make the ratios below undefined.
        if open_[i] <= 0 or close[i - 1] <= 0 or low[i] <= 0:
            continue

        # ---- Candlestick conditions ----
        is_red = close[i] > open_[i]
        # 1: up candle whose high exceeds the previous close by > 0.5 %
        cond1 = is_red and (high[i] / close[i - 1] > 1.005)
        # 2: real body larger than both shadows
        upper_shadow = high[i] - max(close[i], open_[i])
        lower_shadow = min(close[i], open_[i]) - low[i]
        body_size = abs(close[i] - open_[i])
        cond2 = (body_size > upper_shadow) and (body_size > lower_shadow)
        # 3: high/low range < 12 % and high/open > 3.6 %
        cond3 = (high[i] / low[i] < 1.12) and (high[i] / open_[i] > 1.036)
        # 4: exclude limit-up closes (>= +10 %)
        cond4 = close[i] < close[i - 1] * 1.10

        # ---- Indicator conditions ----
        # 5: ATR above 80 % of its mean over the last 5 bars
        cond5 = atr[i] > np.mean(atr[max(0, i - 4):i + 1]) * 0.8
        # 6: MACD histogram grew by more than 20 %
        cond6 = (macd[i] - signal[i]) > (macd[i - 1] - signal[i - 1]) * 1.2

        # ---- Volatility condition ----
        # 7: lowest low / highest high over the window below the regime threshold
        window_start = max(0, i - volatility_window + 1)
        hhv = np.max(high[window_start:i + 1])
        if hhv <= 0:
            cond7 = False
        else:
            llv = np.min(low[window_start:i + 1])
            cond7 = (llv / hhv) < thresholds[i]

        # ---- Volume conditions ----
        # 8: volume below its mean over the previous (up to) 10 bars
        vol_cond1 = volume[i] < np.mean(volume[max(0, i - 10):i])
        # 9: volume below 3.5x the minimum of the preceding window
        # NOTE(review): this window stops at i-2 (excludes bar i-1); confirm intent.
        vol_window = volume[max(0, i - 20):i - 1]
        if len(vol_window) == 0:
            vol_cond2 = False
        else:
            vol_cond2 = volume[i] < np.min(vol_window) * 3.5

        signals[i] = (cond1 and cond2 and cond3 and cond4 and cond5 and cond6
                      and cond7 and vol_cond1 and vol_cond2)

    return signals


def generate_trading_signals(close, open_, high, low, volume, macd, signal, atr,
                             threshold, volatility_window):
    """Return a boolean entry-signal array for one stock.

    ``threshold`` may be a scalar (applied to every bar) or a 1-D array with one
    value per bar, which allows a point-in-time market regime without
    look-ahead. All price/indicator inputs are converted to contiguous float64
    so the numba kernel is compiled for a single signature.

    Raises
    ------
    ValueError
        If an array ``threshold`` does not match the length of ``close``.
    """
    arrays = [np.ascontiguousarray(a, dtype=np.float64)
              for a in (close, open_, high, low, volume, macd, signal, atr)]
    n = arrays[0].shape[0]

    thresholds = np.asarray(threshold, dtype=np.float64)
    if thresholds.ndim == 0:
        thresholds = np.full(n, float(thresholds))
    else:
        thresholds = np.ascontiguousarray(thresholds)
        if thresholds.shape != (n,):
            raise ValueError("threshold must be a scalar or have the same length as close")

    return _generate_trading_signals_core(*arrays, thresholds, int(volatility_window))


# ========== Market regime ==========
def get_market_condition(index_data):
    """Classify the regime at the LAST row of ``index_data``.

    'bull' if MA20 > 1.05 * MA60, 'bear' if MA20 < 0.95 * MA60, otherwise
    'neutral' (also when fewer than 60 rows or the averages are NaN).
    """
    if len(index_data) < 60:
        return 'neutral'

    ma20 = index_data['close'].rolling(20).mean().iloc[-1]
    ma60 = index_data['close'].rolling(60).mean().iloc[-1]

    if pd.isna(ma20) or pd.isna(ma60):
        return 'neutral'

    if ma20 > ma60 * 1.05:
        return 'bull'
    elif ma20 < ma60 * 0.95:
        return 'bear'
    else:
        return 'neutral'


def _market_condition_lookup(index_data):
    """Precompute the regime for every index trading day.

    Gives the same label get_market_condition() would return on the prefix of
    ``index_data`` ending at each row, in O(n) instead of O(n) per query.
    Returns ``(dates, conditions)`` as parallel arrays sorted by date.
    """
    ordered = index_data.sort_values('trade_date', kind='stable')
    dates = pd.to_datetime(ordered['trade_date']).to_numpy(dtype='datetime64[ns]')
    close = ordered['close'].astype(np.float64)
    ma20 = close.rolling(20).mean().to_numpy()
    ma60 = close.rolling(60).mean().to_numpy()

    # Comparisons with NaN (warm-up rows) are False, so those stay 'neutral'.
    conditions = np.full(len(dates), 'neutral', dtype=object)
    conditions[ma20 > ma60 * 1.05] = 'bull'
    conditions[ma20 < ma60 * 0.95] = 'bear'
    return dates, conditions


def _conditions_at(dates, index_dates, index_conditions):
    """Map each date to the regime of the latest index day on or before it.

    Dates earlier than the first index day map to 'neutral'.
    """
    query = pd.to_datetime(pd.Series(dates)).to_numpy(dtype='datetime64[ns]')
    pos = np.searchsorted(index_dates, query, side='right') - 1
    result = np.full(len(pos), 'neutral', dtype=object)
    valid = pos >= 0
    result[valid] = index_conditions[pos[valid]]
    return result


# ========== Data loading ==========
def load_index_data(index_path):
    """Load the tab-separated index file and normalise it.

    The date column may be named 'date' or 'trade_date'. It is always exposed
    as 'trade_date' (datetime64), because every consumer reads that column.
    Rows are sorted by date and restricted to START_DATE onward.
    Returns the DataFrame, or None on any failure.
    """
    try:
        # Peek at the header to detect the date column name.
        header = pd.read_csv(index_path, sep='\t', nrows=0)
        date_col = 'date' if 'date' in header.columns else 'trade_date'

        index_data = pd.read_csv(
            index_path,
            sep='\t',
            usecols=[date_col, 'open', 'high', 'low', 'close', 'volume'],
            dtype={date_col: str},  # parse '%Y%m%d' explicitly below
        )
        index_data = index_data.rename(columns={date_col: 'trade_date'})
        index_data['trade_date'] = pd.to_datetime(
            index_data['trade_date'], format='%Y%m%d', errors='coerce')
        index_data = index_data.dropna(subset=['trade_date'])
        index_data = index_data.sort_values('trade_date')

        # Keep data from 2022 onward only.
        index_data = index_data[index_data['trade_date'] >= START_DATE].reset_index(drop=True)
        if index_data.empty:
            logging.error(f"指数数据为空: {index_path}")
            return None

        logging.info(
            f"指数数据加载成功,时间范围: {index_data['trade_date'].min().date()} 至 {index_data['trade_date'].max().date()}")
        return index_data
    except Exception as e:
        logging.error(f"指数数据加载失败: {str(e)}")
        return None


def process_stock_file(file_path, index_data):
    """Load one stock file and compute indicators and per-day entry signals.

    Returns ``(stock_code, df)``, where ``df`` has a fresh 0..n-1 RangeIndex
    and a boolean 'signal' column, or None if the file cannot be used. The
    stock code is the part of the file name before the first '_'.
    """
    try:
        required_cols = ['trade_date', 'open', 'high', 'low', 'close', 'vol']
        df = pd.read_csv(file_path, sep='\t', usecols=required_cols,
                         dtype={'trade_date': str})

        # Strict sanity filter on prices and volume.
        df = df[
            (df['open'] > 0) &
            (df['close'] > 0) &
            (df['high'] > 0) &
            (df['low'] > 0) &
            (df['high'] >= df['low']) &
            (df['close'] >= df['low']) &
            (df['close'] <= df['high']) &
            (df['vol'] > 0)
        ]

        if df.empty:
            logging.warning(f"文件 {os.path.basename(file_path)} 无有效数据")
            return None

        df = df.rename(columns={'vol': 'volume'})
        df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d', errors='coerce')
        df = df.dropna(subset=['trade_date']).sort_values('trade_date')

        # Align with the index date range (which already starts at START_DATE).
        start_date = max(index_data['trade_date'].min(), START_DATE)
        end_date = index_data['trade_date'].max()
        # reset_index: the backtest addresses rows by position, so labels must
        # equal positions.
        df = df[(df['trade_date'] >= start_date) &
                (df['trade_date'] <= end_date)].reset_index(drop=True)

        close = df['close'].to_numpy(dtype=np.float64)
        open_ = df['open'].to_numpy(dtype=np.float64)
        high = df['high'].to_numpy(dtype=np.float64)
        low = df['low'].to_numpy(dtype=np.float64)
        volume = df['volume'].to_numpy(dtype=np.float64)

        macd, signal, atr = calculate_technical_indicators(
            close, high, low, volume,
            atr_window=ATR_WINDOW,
            volatility_window=VOLATILITY_WINDOW
        )

        # Point-in-time regime threshold for every bar. Using the regime at the
        # end of the index history for all bars would leak future information.
        index_dates, index_conditions = _market_condition_lookup(index_data)
        conditions = _conditions_at(df['trade_date'], index_dates, index_conditions)
        threshold_map = {
            'bull': BULL_THRESHOLD,
            'bear': BEAR_THRESHOLD,
            'neutral': NEUTRAL_THRESHOLD
        }
        thresholds = np.array([threshold_map.get(c, NEUTRAL_THRESHOLD) for c in conditions],
                              dtype=np.float64)

        df['signal'] = generate_trading_signals(
            close, open_, high, low, volume, macd, signal, atr,
            threshold=thresholds,
            volatility_window=VOLATILITY_WINDOW
        )

        stock_code = os.path.basename(file_path).split('_')[0]
        return stock_code, df

    except Exception as e:
        logging.error(f"处理文件 {os.path.basename(file_path)} 失败: {str(e)}")
        return None


# ========== Backtest ==========
def _backtest_stock(stock_code, data, index_dates, index_conditions, results):
    """Append one trade record per entry signal of a single stock to ``results``."""
    n = len(data)
    closes = data['close'].to_numpy(dtype=np.float64)
    opens = data['open'].to_numpy(dtype=np.float64)
    trade_dates = data['trade_date']

    # Row POSITIONS of signal bars; every access below is positional, so a
    # non-default DataFrame index cannot misalign rows.
    positions = np.flatnonzero(data['signal'].to_numpy(dtype=bool))
    conditions = _conditions_at(trade_dates.iloc[positions], index_dates, index_conditions)

    for pos, market_condition in zip(positions, conditions):
        # Need a following bar plus sane entry close and next-day open.
        if pos + 1 >= n or closes[pos] <= 0 or opens[pos + 1] <= 0:
            continue

        holding_days = HOLDING_DAYS_MAP.get(market_condition, 2)

        # Exit window: bars pos+1 .. pos+holding_days+1, truncated at the last bar.
        # NOTE(review): that is holding_days + 1 bars after entry; confirm intended timing.
        exit_idx = min(pos + holding_days + 1, n - 1)
        exit_prices = closes[pos + 1:exit_idx + 1]

        entry_price = closes[pos]
        max_profit = round(float((exit_prices.max() - entry_price) / entry_price), 4)
        max_loss = round(float((exit_prices.min() - entry_price) / entry_price), 4)
        final_return = round(float((exit_prices[-1] - entry_price) / entry_price), 4)

        results.append({
            'code': stock_code,
            'date': trade_dates.iloc[pos].strftime('%Y-%m-%d'),
            'market': market_condition,
            'holding_days': holding_days,
            'return': final_return,
            'max_profit': max_profit,
            'max_loss': max_loss
        })


def backtest_strategy(all_data, index_data):
    """Run the dynamic-holding-period backtest over every stock.

    ``all_data`` maps stock code -> DataFrame produced by process_stock_file.
    For each signal, the holding period follows the market regime known on
    the signal date. Returns one row per trade with columns code, date, market,
    holding_days, return, max_profit and max_loss.
    """
    index_dates, index_conditions = _market_condition_lookup(index_data)
    results = []

    for stock_code, data in all_data.items():
        if data is None or 'signal' not in data.columns:
            continue
        try:
            _backtest_stock(stock_code, data, index_dates, index_conditions, results)
        except Exception as e:
            # One bad stock must not abort the whole backtest.
            logging.error(f"处理股票 {stock_code} 时出错: {str(e)}")

    return pd.DataFrame(results)


def analyze_results(results_df):
    """Print summary statistics of the backtest and plot return diagnostics."""
    if results_df.empty:
        logging.warning("无有效交易记录")
        return

    returns = results_df['return']
    total_trades = len(results_df)

    # Each trade's return spans several trading days, so scale the mean by
    # trades-per-year (252 / average holding period), not by 252. This is a
    # rough linear approximation that ignores overlapping positions.
    avg_holding = results_df['holding_days'].mean() if 'holding_days' in results_df.columns else 1.0
    if not avg_holding > 0:  # NaN or non-positive
        avg_holding = 1.0
    annual_return = returns.mean() * 252 / avg_holding

    win_rate = (returns > 0).sum() / total_trades
    wins = returns[returns > 0]
    losses = returns[returns < 0]
    # Average win / |average loss|; undefined without both wins and losses.
    profit_factor = (wins.mean() / abs(losses.mean())
                     if not wins.empty and not losses.empty else float('nan'))

    print(f"\n策略表现汇总:")
    print(f"总交易次数: {total_trades}")
    print(f"年化收益率: {annual_return:.2%}")
    print(f"胜率: {win_rate:.2%}")
    print(f"盈亏比: {profit_factor:.2f}")

    # Breakdown by market regime.
    if 'market' in results_df.columns:
        market_stats = results_df.groupby('market').agg({
            'return': ['mean', 'count'],
            'holding_days': 'mean'
        })
        print("\n分市场状态表现:")
        print(market_stats)

    # Plots: return histogram, and holding period vs return per regime.
    plt.figure(figsize=(12, 5))
    plt.subplot(121)
    returns.hist(bins=20, alpha=0.7)
    plt.title('收益率分布')
    plt.xlabel('收益率')
    plt.ylabel('频次')

    plt.subplot(122)
    if 'market' in results_df.columns:
        for condition, group in results_df.groupby('market'):
            plt.scatter(group['holding_days'], group['return'], alpha=0.5, label=condition)
        plt.legend()
    plt.axhline(0, color='red', linestyle='--')
    plt.title('持仓周期 vs 收益率')
    plt.xlabel('持仓天数')
    plt.ylabel('收益率')

    plt.tight_layout()
    plt.show()


def _is_excluded_stock_file(filename):
    """Return True for ST stocks and STAR-market codes (code starting with 688).

    The code is the part of the file name before the first '_' (the same rule
    process_stock_file uses). Only its digits are compared, so a plain
    substring such as '600688' is no longer excluded by mistake.
    """
    if 'ST' in filename:  # also covers '*ST'
        return True
    code = filename.split('_')[0]
    digits = ''.join(ch for ch in code if ch.isdigit())
    return digits.startswith('688')


# ========== Main ==========
if __name__ == "__main__":
    # Paths
    STOCK_DIR = 'day/'
    INDEX_PATH = 'index/000001.SH.txt'

    logging.info("正在加载指数数据...")
    index_data = load_index_data(INDEX_PATH)
    if index_data is None:
        sys.exit(1)

    # Process the stock files in parallel.
    logging.info("正在加载个股数据...")
    stock_files = [os.path.join(STOCK_DIR, f) for f in os.listdir(STOCK_DIR)
                   if f.endswith('.txt') and not _is_excluded_stock_file(f)]

    all_data = {}
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = {executor.submit(process_stock_file, f, index_data): f for f in stock_files}
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            if result:
                code, data = result
                all_data[code] = data

    if not all_data:
        logging.error("没有加载到有效股票数据")
        sys.exit(1)

    logging.info("开始回测...")
    results_df = backtest_strategy(all_data, index_data)

    analyze_results(results_df)

    if not results_df.empty:
        results_df.to_csv('strategy_backtest_results.csv', index=False)
        logging.info("回测结果已保存至 strategy_backtest_results.csv")