diff --git a/尾盘_数据统计_优化01.py b/尾盘_数据统计_优化01.py new file mode 100644 index 0000000..4bf9cc9 --- /dev/null +++ b/尾盘_数据统计_优化01.py @@ -0,0 +1,324 @@ +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, as_completed +import logging +from numba import jit +from datetime import datetime + +# ========== 环境配置 ========== +plt.rcParams['font.sans-serif'] = ['SimHei'] +plt.rcParams['axes.unicode_minus'] = False +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + +# ========== 策略参数配置 ========== +# ========== 全局参数 ========== +ATR_WINDOW = 5 +VOLATILITY_WINDOW = 10 +BULL_THRESHOLD = 0.83 +BEAR_THRESHOLD = 0.77 +NEUTRAL_THRESHOLD = 0.81 +HOLDING_DAYS_MAP = { + 'bull': 4, + 'bear': 2, + 'neutral': 3 +} + + + +# ========== 核心策略逻辑 ========== +@jit(nopython=True) +def calculate_technical_indicators(close, high, low, volume, + atr_window, volatility_window): + """支持Numba的参数传递""" + n = len(close) + macd = np.zeros(n) + signal = np.zeros(n) + atr = np.zeros(n) + + # MACD计算 + ema12, ema26 = np.zeros(n), np.zeros(n) + for i in range(1, n): + ema12[i] = ema12[i - 1] * 11 / 13 + close[i] * 2 / 13 + ema26[i] = ema26[i - 1] * 25 / 27 + close[i] * 2 / 27 + macd[i] = ema12[i] - ema26[i] + signal[i] = signal[i - 1] * 0.8 + macd[i] * 0.2 + + # ATR计算 + for i in range(1, n): + tr = max(high[i] - low[i], + abs(high[i] - close[i - 1]), + abs(low[i] - close[i - 1])) + atr[i] = atr[i-1] * (atr_window-1)/atr_window + tr/atr_window + + return macd, signal, atr + + +@jit(nopython=True) +def generate_trading_signals(close, open_, high, low, volume, macd, signal, atr, threshold, volatility_window): + """生成交易信号""" + n = len(close) + signals = np.zeros(n, dtype=np.bool_) + + for i in range(3, n): + # 基础K线形态条件 + is_red = close[i] > open_[i] + upper_shadow = high[i] - max(close[i], open_[i]) + lower_shadow = min(close[i], open_[i]) - low[i] + body_size = 
abs(close[i] - open_[i]) + + cond1 = is_red and (high[i] / close[i - 1] > 1.005) + cond2 = (body_size > upper_shadow) and (body_size > lower_shadow) + cond3 = (high[i] / low[i] < 1.12) and (high[i] / open_[i] > 1.036) + cond4 = close[i] < close[i - 1] * 1.10 # 排除涨停 + + # 技术指标条件 + cond5 = atr[i] > np.mean(atr[i - 4:i + 1]) * 0.8 + cond6 = (macd[i] - signal[i]) > (macd[i - 1] - signal[i - 1]) * 1.2 + + # 波动率条件 + llv = np.min(low[max(0, i - volatility_window + 1):i + 1]) + hhv = np.max(high[max(0, i - volatility_window + 1):i + 1]) + cond7 = (llv / hhv) < threshold + + # 量能条件 + vol_cond1 = volume[i] < np.mean(volume[max(0, i - 10):i]) + vol_cond2 = volume[i] < np.min(volume[max(0, i - 20):i - 1]) * 3.5 + + signals[i] = cond1 & cond2 & cond3 & cond4 & cond5 & cond6 & cond7 & vol_cond1 & vol_cond2 + + return signals + + +# ========== 市场状态判断 ========== +def get_market_condition(index_data): + """动态判断市场状态""" + if len(index_data) < 60: + return 'neutral' + + ma20 = index_data['close'].rolling(20).mean().iloc[-1] + ma60 = index_data['close'].rolling(60).mean().iloc[-1] + + if pd.isna(ma20) or pd.isna(ma60): + return 'neutral' + + if ma20 > ma60 * 1.05: + return 'bull' + elif ma20 < ma60 * 0.95: + return 'bear' + else: + return 'neutral' + + +# ========== 数据加载处理 ========== +def load_index_data(index_path): + """加载并预处理指数数据""" + try: + # 自动检测日期列名 + df = pd.read_csv(index_path, sep='\t', nrows=0) + date_col = 'date' if 'date' in df.columns else 'trade_date' + + index_data = pd.read_csv( + index_path, + sep='\t', + usecols=[date_col, 'open', 'high', 'low', 'close', 'volume'], + parse_dates=[date_col], + date_parser=lambda x: pd.to_datetime(x, format='%Y%m%d') + ) + # index_data.rename(columns={date_col: 'trade_date'}, inplace=True) + index_data.sort_values(date_col, inplace=True) + + logging.info( + f"指数数据加载成功,时间范围: {index_data['trade_date'].min().date()} 至 {index_data['trade_date'].max().date()}") + return index_data + except Exception as e: + logging.error(f"指数数据加载失败: 
{str(e)}") + return None + + +def process_stock_file(file_path, index_data): + """处理单个股票文件""" + try: + # 加载并预处理数据 + df = pd.read_csv(file_path, sep='\t', + usecols=['trade_date', 'open', 'high', 'low', 'close', 'vol']) + df = df.rename(columns={'vol': 'volume'}) + df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d', errors='coerce') + df = df.dropna(subset=['trade_date']).sort_values('trade_date') + + # 对齐指数时间范围 + start_date = index_data['trade_date'].min() + end_date = index_data['trade_date'].max() + df = df[(df['trade_date'] >= start_date) & (df['trade_date'] <= end_date)] + + # if len(df) < StrategyConfig.MIN_TRADE_DAYS: + # return None + + # 计算技术指标 + close = df['close'].values.astype(np.float64) + high = df['high'].values.astype(np.float64) + low = df['low'].values.astype(np.float64) + volume = df['volume'].values.astype(np.float64) + + macd, signal, atr = calculate_technical_indicators( + close, high, low, volume, + atr_window=ATR_WINDOW, + volatility_window=VOLATILITY_WINDOW + ) + + # # 获取市场状态 + # market_condition = get_market_condition(index_data) + # threshold = StrategyConfig.THRESHOLDS[market_condition] + + # 生成信号 + signals = generate_trading_signals( + close, df['open'].values, high, low, volume, + macd, signal, atr, threshold=BULL_THRESHOLD, + volatility_window=VOLATILITY_WINDOW + ) + + df['signal'] = signals + return os.path.basename(file_path).split('_')[0], df + + except Exception as e: + logging.error(f"处理文件 {os.path.basename(file_path)} 失败: {str(e)}") + return None + + +# ========== 回测分析模块 ========== +def backtest_strategy(all_data, index_data): + """执行动态持仓周期回测""" + results = [] + + for stock_code, data in all_data.items(): + if data is None or 'signal' not in data.columns: + continue + + signals = data[data['signal']] + for idx in signals.index: + # 动态获取市场状态 + current_date = data.iloc[idx]['trade_date'] + market_condition = get_market_condition(index_data) + holding_days = HOLDING_DAYS_MAP.get(market_condition, 2) + + # 计算退出时间 + 
exit_idx = idx + holding_days + 1 # 包含买入当天 + + if exit_idx >= len(data): + continue + + # 计算收益 + entry_price = data.loc[idx, 'close'] + exit_prices = data.iloc[idx + 1:exit_idx]['close'] + + max_profit = (exit_prices.max() - entry_price) / entry_price + max_loss = (exit_prices.min() - entry_price) / entry_price + final_return = (exit_prices.iloc[-1] - entry_price) / entry_price + + results.append({ + 'code': stock_code, + 'date': current_date.strftime('%Y-%m-%d'), + 'market': market_condition, + 'holding_days': holding_days, + 'return': final_return, + 'max_profit': max_profit, + 'max_loss': max_loss + }) + + return pd.DataFrame(results) if results else pd.DataFrame() + + +def analyze_results(results_df): + """分析回测结果""" + if results_df.empty: + logging.warning("无有效交易记录") + return + + # 基础统计 + total_trades = len(results_df) + annual_return = results_df['return'].mean() * 252 + win_rate = len(results_df[results_df['return'] > 0]) / total_trades + profit_factor = results_df[results_df['return'] > 0]['return'].mean() / \ + abs(results_df[results_df['return'] < 0]['return'].mean()) + + print(f"\n策略表现汇总:") + print(f"总交易次数: {total_trades}") + print(f"年化收益率: {annual_return:.2%}") + print(f"胜率: {win_rate:.2%}") + print(f"盈亏比: {profit_factor:.2f}") + + # 分市场状态分析 + if 'market' in results_df.columns: + market_stats = results_df.groupby('market').agg({ + 'return': ['mean', 'count'], + 'holding_days': 'mean' + }) + print("\n分市场状态表现:") + print(market_stats) + + # 可视化 + plt.figure(figsize=(12, 5)) + plt.subplot(121) + results_df['return'].hist(bins=20, alpha=0.7) + plt.title('收益率分布') + plt.xlabel('收益率') + plt.ylabel('频次') + + plt.subplot(122) + if 'market' in results_df.columns: + for condition, group in results_df.groupby('market'): + plt.scatter(group['holding_days'], group['return'], alpha=0.5, label=condition) + plt.legend() + plt.axhline(0, color='red', linestyle='--') + plt.title('持仓周期 vs 收益率') + plt.xlabel('持仓天数') + plt.ylabel('收益率') + + plt.tight_layout() + plt.show() + + 
# ========== Main program ==========
if __name__ == "__main__":
    # Input locations.
    STOCK_DIR = '/day/'
    INDEX_PATH = '/index/000001.SH.txt'

    # Load the benchmark index; nothing can be done without it.
    logging.info("正在加载指数数据...")
    index_data = load_index_data(INDEX_PATH)
    if index_data is None:
        # Non-zero status so callers/schedulers can see the run failed.
        raise SystemExit(1)

    # Collect the stock files, skipping names that contain an excluded
    # keyword ('ST' already covers '*ST').
    # NOTE(review): '688' is matched anywhere in the file name, so a code
    # such as 600688 is skipped as well — confirm that is intended.
    logging.info("正在加载个股数据...")
    excluded_keywords = ('ST', '688')
    stock_files = [
        os.path.join(STOCK_DIR, name)
        for name in os.listdir(STOCK_DIR)
        if name.endswith('.txt') and not any(kw in name for kw in excluded_keywords)
    ]

    # Process the files in parallel, one task per file.
    all_data = {}
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = {executor.submit(process_stock_file, path, index_data): path
                   for path in stock_files}
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            if result:
                # process_stock_file returns a (code, DataFrame) pair;
                # unpacking three values here raised ValueError.
                code, data = result
                all_data[code] = data

    if not all_data:
        logging.error("没有加载到有效股票数据")
        raise SystemExit(1)

    # Run the back-test.
    logging.info("开始回测...")
    results_df = backtest_strategy(all_data, index_data)

    # Report.
    analyze_results(results_df)

    # Persist the trade list.
    if not results_df.empty:
        results_df.to_csv('strategy_backtest_results.csv', index=False)
        logging.info("回测结果已保存至 strategy_backtest_results.csv")