# File info: 412 lines, 14 KiB, Python
import os
|
||
from numba import types # 正确导入
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from tqdm import tqdm
|
||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||
import logging
|
||
from numba import jit
|
||
from datetime import datetime
|
||
|
||
# ========== Environment configuration ==========
# Matplotlib: use the SimHei font so the Chinese plot titles/labels render, and
# draw the minus sign as a plain hyphen (axes.unicode_minus=False).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Root logger: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
||
|
||
|
||
# ========== Strategy parameters ==========
# Smoothing window (bars) of the Wilder-style ATR recursion.
ATR_WINDOW = 5
# Look-back (bars, current bar included) for the highest-high / lowest-low
# range used by the volatility condition of generate_trading_signals.
VOLATILITY_WINDOW = 10

# Upper bound on lowest-low / highest-high over VOLATILITY_WINDOW, chosen by
# the market regime ('bull' / 'bear' / 'neutral').
BULL_THRESHOLD = 0.83
BEAR_THRESHOLD = 0.77
NEUTRAL_THRESHOLD = 0.81

# Holding period (trading days) used by the backtest for each market regime.
HOLDING_DAYS_MAP = dict(bull=4, bear=2, neutral=3)
|
||
|
||
|
||
|
||
# ========== 核心策略逻辑 ==========
|
||
@jit(nopython=True)
def calculate_technical_indicators(close, high, low, volume,
                                   atr_window, volatility_window):
    """Compute MACD, its signal line and ATR for one price series.

    Numba-compiled in nopython mode, so arguments must be NumPy arrays/scalars.

    Parameters
    ----------
    close, high, low : 1-D float arrays of equal length.
    volume : accepted for interface compatibility; not used here.
    atr_window : smoothing window of the Wilder-style ATR recursion.
    volatility_window : accepted for interface compatibility; not used here.

    Returns
    -------
    (macd, signal, atr), three float arrays with the length of ``close``:
        macd   = EMA12(close) - EMA26(close)
        signal = EMA of macd with alpha 0.2 (a 9-period EMA)
        atr    = true range smoothed with alpha 1 / atr_window
    """
    n = len(close)
    macd = np.zeros(n)
    signal = np.zeros(n)
    atr = np.zeros(n)
    if n == 0:
        return macd, signal, atr

    # Seed every recursion with the first bar instead of 0.  Starting the EMAs
    # at 0 makes EMA12 and EMA26 ramp up at different speeds, which produces a
    # large spurious MACD for the first few dozen bars; starting ATR at 0 makes
    # it climb monotonically at first, distorting any "ATR vs recent ATR" test.
    ema12 = np.empty(n)
    ema26 = np.empty(n)
    ema12[0] = close[0]
    ema26[0] = close[0]
    # macd[0] = signal[0] = 0 because both EMAs start at the same value.
    atr[0] = high[0] - low[0]  # no previous close yet: TR is the bar range

    # MACD
    for i in range(1, n):
        ema12[i] = ema12[i - 1] * 11 / 13 + close[i] * 2 / 13
        ema26[i] = ema26[i - 1] * 25 / 27 + close[i] * 2 / 27
        macd[i] = ema12[i] - ema26[i]
        signal[i] = signal[i - 1] * 0.8 + macd[i] * 0.2

    # ATR: true range includes gaps against the previous close.
    for i in range(1, n):
        tr = max(high[i] - low[i],
                 abs(high[i] - close[i - 1]),
                 abs(low[i] - close[i - 1]))
        atr[i] = atr[i - 1] * (atr_window - 1) / atr_window + tr / atr_window

    return macd, signal, atr
|
||
|
||
|
||
@jit(nopython=True)
def generate_trading_signals(close, open_, high, low, volume,
                             macd, signal, atr,
                             threshold, volatility_window):
    """Flag bars that pass every candlestick, indicator, range and volume filter.

    Numba-compiled in nopython mode.  All array arguments are 1-D and aligned
    bar-by-bar with ``close``; ``macd`` / ``signal`` / ``atr`` are the outputs
    of ``calculate_technical_indicators``.

    Parameters
    ----------
    threshold : upper bound for (lowest low / highest high) over the last
        ``volatility_window`` bars (condition 7).
    volatility_window : length of that look-back window, current bar included.

    Returns
    -------
    Boolean array of len(close); True where all nine conditions hold.
    The first three bars are always False.
    """
    n = len(close)
    signals = np.zeros(n, dtype=np.bool_)

    for i in range(3, n):
        # Skip bars where a ratio below would divide by a non-positive price.
        # low[i] is the divisor of condition 3; in numba's default
        # (python) error model a zero divisor raises ZeroDivisionError.
        if open_[i] <= 0 or close[i - 1] <= 0 or low[i] <= 0:
            continue

        # ========== candlestick shape ==========
        # 1: up bar (close > open) whose high exceeds the previous close by >0.5%.
        is_red = close[i] > open_[i]
        cond1 = is_red and (high[i] / close[i - 1] > 1.005)

        # 2: real body larger than both the upper and the lower shadow.
        upper_shadow = high[i] - max(close[i], open_[i])
        lower_shadow = min(close[i], open_[i]) - low[i]
        body_size = abs(close[i] - open_[i])
        cond2 = (body_size > upper_shadow) and (body_size > lower_shadow)

        # 3: high/low below 1.12 and high/open above 1.036.
        cond3 = (high[i] / low[i] < 1.12) and (high[i] / open_[i] > 1.036)

        # 4: exclude bars that closed 10% or more above the previous close (limit-up).
        cond4 = close[i] < close[i - 1] * 1.10

        # ========== indicators ==========
        # 5: ATR above 80% of its mean over the last 5 bars (current included).
        cond5 = atr[i] > np.mean(atr[max(0, i - 4):i + 1]) * 0.8

        # 6: MACD histogram grew by at least 20% of its previous magnitude.
        # `hist_now > hist_prev * 1.2` would invert for a negative hist_prev
        # (it passed even when the histogram fell); for hist_prev >= 0 the
        # form below is identical to it.
        hist_now = macd[i] - signal[i]
        hist_prev = macd[i - 1] - signal[i - 1]
        cond6 = (hist_now - hist_prev) > 0.2 * abs(hist_prev)

        # ========== price range ==========
        # 7: lowest low / highest high over the window below the regime threshold.
        window_start = max(0, i - volatility_window + 1)
        hhv = np.max(high[window_start:i + 1])
        if hhv <= 0:
            cond7 = False
        else:
            llv = np.min(low[window_start:i + 1])
            cond7 = (llv / hhv) < threshold

        # ========== volume ==========
        # 8: volume below the mean of the previous (up to) 10 bars.
        vol_cond1 = volume[i] < np.mean(volume[max(0, i - 10):i])

        # 9: volume below 3.5x the minimum of bars [i-20, i-2].
        # NOTE(review): this window excludes bar i-1 while condition 8's
        # includes it; confirm that is intended.
        vol_window = volume[max(0, i - 20):i - 1]
        if len(vol_window) == 0:
            vol_cond2 = False
        else:
            vol_cond2 = volume[i] < np.min(vol_window) * 3.5

        signals[i] = (cond1 and cond2 and cond3 and cond4 and cond5
                      and cond6 and cond7 and vol_cond1 and vol_cond2)

    return signals
|
||
|
||
|
||
# ========== 市场状态判断 ==========
|
||
def get_market_condition(index_data):
    """Classify the market regime from the latest moving averages of the index.

    Compares the last 20-bar and 60-bar simple moving averages of
    ``index_data['close']``:

    * ``'bull'``    - MA20 more than 5% above MA60;
    * ``'bear'``    - MA20 more than 5% below MA60;
    * ``'neutral'`` - otherwise, when there are fewer than 60 rows, or when
      either average is NaN (e.g. a missing close inside its window).
    """
    if len(index_data) < 60:
        return 'neutral'

    closes = index_data['close']
    short_ma, long_ma = (closes.rolling(window).mean().iloc[-1] for window in (20, 60))

    if pd.isna(short_ma) or pd.isna(long_ma):
        return 'neutral'
    if short_ma > long_ma * 1.05:
        return 'bull'
    if short_ma < long_ma * 0.95:
        return 'bear'
    return 'neutral'
|
||
|
||
|
||
# ========== 数据加载处理 ==========
|
||
def load_index_data(index_path):
    """Load the tab-separated index price file and normalise it for the backtest.

    The date column may be named ``date`` or ``trade_date`` (values formatted
    YYYYMMDD); it is always returned as a datetime column named
    ``trade_date``, which is what every consumer of the index reads.

    Returns
    -------
    DataFrame with trade_date, open, high, low, close and volume, sorted by
    date, limited to 2022-01-01 onwards, with a fresh 0..n-1 index; or
    ``None`` if the file cannot be read/parsed or no rows remain (the error
    is logged).
    """
    try:
        # Detect the date column name from the header only.
        header = pd.read_csv(index_path, sep='\t', nrows=0)
        date_col = 'date' if 'date' in header.columns else 'trade_date'

        index_data = pd.read_csv(
            index_path,
            sep='\t',
            usecols=[date_col, 'open', 'high', 'low', 'close', 'volume'],
            # Read dates as text and parse explicitly: `date_parser` is
            # deprecated in pandas 2.x and removed afterwards.
            dtype={date_col: str},
        )
        # Consumers (logging below, process_stock_file, backtest_strategy)
        # all use 'trade_date'; without this rename a 'date' file failed.
        index_data = index_data.rename(columns={date_col: 'trade_date'})
        index_data['trade_date'] = pd.to_datetime(index_data['trade_date'], format='%Y%m%d')
        index_data = index_data.sort_values('trade_date')

        # Keep data from 2022 onwards.
        index_data = index_data[index_data['trade_date'] >= pd.Timestamp('2022-01-01')]
        index_data = index_data.reset_index(drop=True)

        if index_data.empty:
            logging.error("指数数据加载失败: 2022-01-01 之后无有效数据")
            return None

        logging.info(
            f"指数数据加载成功,时间范围: {index_data['trade_date'].min().date()} 至 {index_data['trade_date'].max().date()}")
        return index_data
    except Exception as e:
        logging.error(f"指数数据加载失败: {str(e)}")
        return None
|
||
|
||
|
||
def _market_condition_by_date(index_data, dates):
    """Return the point-in-time market regime for each date in ``dates``.

    For every date ``d`` the value equals
    ``get_market_condition(index_data[index_data['trade_date'] <= d])``:
    MA20 vs MA60 of the index close on the latest index bar not after ``d``,
    and 'neutral' while fewer than 60 bars (or NaN averages) are available.
    Assumes ``index_data`` is sorted by 'trade_date' (load_index_data sorts it).
    """
    closes = index_data['close']
    ma20 = closes.rolling(20).mean().to_numpy()
    ma60 = closes.rolling(60).mean().to_numpy()
    # Comparisons with NaN are False, so bars without a full 60-bar history
    # (or with missing closes in the window) fall through to 'neutral'.
    regime = np.where(ma20 > ma60 * 1.05, 'bull',
                      np.where(ma20 < ma60 * 0.95, 'bear', 'neutral'))

    index_dates = index_data['trade_date'].to_numpy()
    # Position of the last index bar on or before each date (-1 if none).
    pos = np.searchsorted(index_dates, np.asarray(dates), side='right') - 1
    result = np.full(len(pos), 'neutral', dtype='<U7')
    valid = pos >= 0
    result[valid] = regime[pos[valid]]
    return result


def process_stock_file(file_path, index_data):
    """Load one stock file, compute indicators and attach a boolean 'signal' column.

    Parameters
    ----------
    file_path : tab-separated file with columns trade_date (YYYYMMDD), open,
        high, low, close, vol.
    index_data : index DataFrame from ``load_index_data`` ('trade_date' and
        'close' columns, sorted by date).

    Returns
    -------
    ``(stock_code, df)``: ``stock_code`` is the file name up to the first '_';
    ``df`` contains trade_date, open, high, low, close, volume and signal with
    a fresh 0..n-1 index, so row labels equal ``iloc`` positions.  Returns
    ``None`` when no usable rows remain or any error occurs (errors are logged).
    """
    try:
        required_cols = ['trade_date', 'open', 'high', 'low', 'close', 'vol']
        df = pd.read_csv(file_path, sep='\t', usecols=required_cols)

        # Drop rows with non-positive prices/volume or inconsistent OHLC.
        df = df[
            (df['open'] > 0) &
            (df['close'] > 0) &
            (df['high'] > 0) &
            (df['low'] > 0) &
            (df['high'] >= df['low']) &
            (df['close'] >= df['low']) &
            (df['close'] <= df['high']) &
            (df['vol'] > 0)
        ]
        if df.empty:
            logging.warning(f"文件 {os.path.basename(file_path)} 无有效数据")
            return None

        df = df.rename(columns={'vol': 'volume'})
        df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d', errors='coerce')
        df = df.dropna(subset=['trade_date']).sort_values('trade_date')

        # Keep only dates covered by the index and from 2022 onwards.
        start_date = index_data['trade_date'].min()
        end_date = index_data['trade_date'].max()
        df = df[(df['trade_date'] >= start_date)
                & (df['trade_date'] <= end_date)
                & (df['trade_date'] >= pd.Timestamp('2022-01-01'))]
        # Positional index: downstream code addresses rows with iloc.
        df = df.reset_index(drop=True)
        if df.empty:
            logging.warning(f"文件 {os.path.basename(file_path)} 无有效数据")
            return None

        close = df['close'].to_numpy(dtype=np.float64)
        open_ = df['open'].to_numpy(dtype=np.float64)
        high = df['high'].to_numpy(dtype=np.float64)
        low = df['low'].to_numpy(dtype=np.float64)
        volume = df['volume'].to_numpy(dtype=np.float64)

        macd, signal, atr = calculate_technical_indicators(
            close, high, low, volume,
            atr_window=ATR_WINDOW,
            volatility_window=VOLATILITY_WINDOW
        )

        # Use the regime known on each bar's date.  Evaluating the regime once
        # on the full index applied the *last* bar's regime to the whole
        # history, i.e. look-ahead bias in the backtest.
        regimes = _market_condition_by_date(index_data, df['trade_date'].to_numpy())
        threshold_map = {
            'bull': BULL_THRESHOLD,
            'bear': BEAR_THRESHOLD,
            'neutral': NEUTRAL_THRESHOLD,
        }
        signals = np.zeros(len(df), dtype=bool)
        for regime, threshold in threshold_map.items():
            in_regime = regimes == regime
            if not in_regime.any():
                continue
            # The threshold only enters the per-bar condition 7, so bars of
            # this regime can be taken from a run with this regime's threshold.
            regime_signals = generate_trading_signals(
                close, open_, high, low, volume,
                macd, signal, atr,
                threshold=threshold,
                volatility_window=VOLATILITY_WINDOW
            )
            signals[in_regime] = regime_signals[in_regime]

        df['signal'] = signals
        stock_code = os.path.basename(file_path).split('_')[0]
        return stock_code, df

    except Exception as e:
        logging.error(f"处理文件 {os.path.basename(file_path)} 失败: {str(e)}")
        return None
|
||
|
||
|
||
# ========== 回测分析模块 ==========
|
||
def backtest_strategy(all_data, index_data):
    """Backtest every signal with a market-regime-dependent holding period.

    Parameters
    ----------
    all_data : mapping stock code -> DataFrame with 'trade_date', 'open',
        'close' and a boolean 'signal' column (entries that are None or lack
        'signal' are skipped).
    index_data : index DataFrame with 'trade_date' and 'close'.

    For each bar whose 'signal' is True (and that has a following bar):
    * the entry price is that bar's close;
    * the regime from ``get_market_condition`` on the index bars up to the
      signal date selects ``holding_days`` via ``HOLDING_DAYS_MAP`` (default 2);
    * the evaluation window is the next ``holding_days + 1`` bars, truncated
      at the end of the data.

    Returns
    -------
    DataFrame, one row per trade: code, date ('YYYY-MM-DD'), market,
    holding_days, return (last close of the window vs entry), max_profit and
    max_loss (highest / lowest close of the window vs entry), rounded to 4
    decimals.
    """
    results = []

    for stock_code, data in all_data.items():
        if data is None or 'signal' not in data.columns:
            continue

        try:
            n_bars = len(data)
            # Work with row positions, not index labels: the frame may have
            # been filtered/sorted, so labels need not be 0..n-1 and
            # data.iloc[label] would pick the wrong row or raise.
            signal_positions = np.flatnonzero(data['signal'].to_numpy(dtype=bool))

            for pos in signal_positions:
                if pos + 1 >= n_bars:
                    continue  # no bar after the signal to trade on
                entry = data.iloc[pos]
                next_day_data = data.iloc[pos + 1]
                if entry['close'] <= 0 or next_day_data['open'] <= 0:
                    continue

                # Regime as known on the signal date.
                current_date = entry['trade_date']
                historical_index = index_data[index_data['trade_date'] <= current_date]
                market_condition = get_market_condition(historical_index)
                holding_days = HOLDING_DAYS_MAP.get(market_condition, 2)

                # NOTE(review): the window covers holding_days + 1 bars after
                # the entry bar; confirm this offset is intended.
                exit_idx = min(pos + holding_days + 1, n_bars - 1)
                exit_prices = data['close'].iloc[pos + 1:exit_idx + 1]

                # entry_price > 0 was checked above, so the divisions are safe.
                entry_price = entry['close']
                max_profit = round((exit_prices.max() - entry_price) / entry_price, 4)
                max_loss = round((exit_prices.min() - entry_price) / entry_price, 4)
                final_return = round((exit_prices.iloc[-1] - entry_price) / entry_price, 4)

                results.append({
                    'code': stock_code,
                    'date': current_date.strftime('%Y-%m-%d'),
                    'market': market_condition,
                    'holding_days': holding_days,
                    'return': final_return,
                    'max_profit': max_profit,
                    'max_loss': max_loss
                })
        except Exception as e:
            logging.error(f"处理股票 {stock_code} 时出错: {str(e)}")

    return pd.DataFrame(results)
|
||
|
||
|
||
def analyze_results(results_df):
    """Print summary statistics of the backtest trades and plot them.

    Expects the DataFrame produced by ``backtest_strategy`` ('return' and
    'holding_days' columns, optionally 'market').  When it is empty, only a
    warning is logged.

    Printed metrics:
    * total number of trades;
    * annualised return: mean of per-trade return divided by the trade's
      holding_days, times 252 (simple, non-compounded approximation);
    * win rate: share of trades with a positive return;
    * 盈亏比: mean winning return / |mean losing return| (inf when there are
      winners but no losers, NaN when there are neither).
    """
    if results_df.empty:
        logging.warning("无有效交易记录")
        return

    returns = results_df['return']
    total_trades = len(results_df)

    # Each trade spans several bars; multiplying the raw per-trade mean by 252
    # treated every trade as a one-day return and overstated the result by
    # roughly the holding period.  Scale to a per-day figure first.
    daily_returns = returns / results_df['holding_days'].clip(lower=1)
    annual_return = daily_returns.mean() * 252

    winners = returns[returns > 0]
    losers = returns[returns < 0]
    win_rate = len(winners) / total_trades

    avg_win = winners.mean() if not winners.empty else 0.0
    avg_loss = abs(losers.mean()) if not losers.empty else 0.0
    if avg_loss > 0:
        profit_factor = avg_win / avg_loss
    else:
        profit_factor = float('inf') if avg_win > 0 else float('nan')

    print(f"\n策略表现汇总:")
    print(f"总交易次数: {total_trades}")
    print(f"年化收益率: {annual_return:.2%}")
    print(f"胜率: {win_rate:.2%}")
    print(f"盈亏比: {profit_factor:.2f}")

    # Breakdown by market regime.
    if 'market' in results_df.columns:
        market_stats = results_df.groupby('market').agg({
            'return': ['mean', 'count'],
            'holding_days': 'mean'
        })
        print("\n分市场状态表现:")
        print(market_stats)

    # Left: return histogram; right: holding days vs return per regime.
    plt.figure(figsize=(12, 5))
    plt.subplot(121)
    returns.hist(bins=20, alpha=0.7)
    plt.title('收益率分布')
    plt.xlabel('收益率')
    plt.ylabel('频次')

    plt.subplot(122)
    if 'market' in results_df.columns:
        for condition, group in results_df.groupby('market'):
            plt.scatter(group['holding_days'], group['return'], alpha=0.5, label=condition)
        plt.legend()
    plt.axhline(0, color='red', linestyle='--')
    plt.title('持仓周期 vs 收益率')
    plt.xlabel('持仓天数')
    plt.ylabel('收益率')

    plt.tight_layout()
    plt.show()
|
||
|
||
|
||
# ========== 主程序 ==========
|
||
if __name__ == "__main__":
    import sys

    # Input locations, relative to the working directory.
    STOCK_DIR = 'day/'
    INDEX_PATH = 'index/000001.SH.txt'

    logging.info("正在加载指数数据...")
    index_data = load_index_data(INDEX_PATH)
    if index_data is None:
        sys.exit(1)  # non-zero status: the run failed

    logging.info("正在加载个股数据...")
    # Skip ST / *ST names (substring 'ST' in the file name; '*ST' contains it)
    # and STAR-Market codes, which start with 688.  Matching '688' anywhere in
    # the name also dropped unrelated codes such as 600688.
    # Assumes file names begin with the stock code, as process_stock_file's
    # split('_')[0] does.
    stock_files = [
        os.path.join(STOCK_DIR, f) for f in os.listdir(STOCK_DIR)
        if f.endswith('.txt') and 'ST' not in f and not f.startswith('688')
    ]

    # Process stock files in parallel.
    all_data = {}
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = {executor.submit(process_stock_file, f, index_data): f for f in stock_files}
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            if result is not None:
                code, data = result
                all_data[code] = data

    if not all_data:
        logging.error("没有加载到有效股票数据")
        sys.exit(1)

    logging.info("开始回测...")
    results_df = backtest_strategy(all_data, index_data)

    analyze_results(results_df)

    if not results_df.empty:
        results_df.to_csv('strategy_backtest_results.csv', index=False)
        logging.info("回测结果已保存至 strategy_backtest_results.csv")