Files
weipan_cl/尾盘_数据统计_优化01.py

412 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from numba import types # 正确导入
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import logging
from numba import jit
from datetime import datetime
# ========== 环境配置 ==========
# Plot configuration: SimHei is selected as the sans-serif font (the charts
# produced by this script use Chinese titles and axis labels), and
# axes.unicode_minus is switched off.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Root logger: INFO level, "timestamp - level - message" layout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
# ========== 策略参数配置 ==========
# ========== 全局参数 ==========
# Look-back lengths (in bars) passed to the indicator and signal routines.
ATR_WINDOW = 5
VOLATILITY_WINDOW = 10

# Per-regime upper bound for the lowest-low / highest-high ratio used by the
# signal generator's range condition.
BULL_THRESHOLD = 0.83
BEAR_THRESHOLD = 0.77
NEUTRAL_THRESHOLD = 0.81

# Holding period (in bars) applied by the backtest for each market regime.
HOLDING_DAYS_MAP = dict(bull=4, bear=2, neutral=3)
# ========== 核心策略逻辑 ==========
@jit(nopython=True)
def calculate_technical_indicators(close, high, low, volume,
                                   atr_window, volatility_window):
    """Compute MACD, its signal line and a smoothed ATR for one price series.

    Args:
        close, high, low: 1-D float arrays of equal length, oldest bar first.
        volume: Accepted for signature compatibility; not used here.
        atr_window: Smoothing length of the ATR recursion
            (``atr[i] = atr[i-1] * (w-1)/w + tr/w``).
        volatility_window: Accepted for signature compatibility; not used here.

    Returns:
        Tuple ``(macd, signal, atr)`` of float arrays with ``len(close)``
        elements. ``macd`` is EMA12 - EMA26 and ``signal`` is a 0.8/0.2
        exponential smoothing of ``macd``.
    """
    n = len(close)
    macd = np.zeros(n)
    signal = np.zeros(n)
    atr = np.zeros(n)
    if n == 0:
        # Nothing to seed from; indexing element 0 below would be invalid.
        return macd, signal, atr

    ema12 = np.zeros(n)
    ema26 = np.zeros(n)
    # Seed the recursions with the first observation. Starting them at 0
    # makes both EMAs ramp up from zero at different speeds, so the early
    # MACD values measure that ramp rather than price momentum, and the
    # early ATR is a rising ramp instead of the actual bar range.
    ema12[0] = close[0]
    ema26[0] = close[0]
    atr[0] = high[0] - low[0]

    for i in range(1, n):
        # MACD: difference of the 12- and 26-period EMAs, plus signal line.
        ema12[i] = ema12[i - 1] * 11 / 13 + close[i] * 2 / 13
        ema26[i] = ema26[i - 1] * 25 / 27 + close[i] * 2 / 27
        macd[i] = ema12[i] - ema26[i]
        signal[i] = signal[i - 1] * 0.8 + macd[i] * 0.2

        # ATR: true range smoothed over atr_window bars.
        tr = max(high[i] - low[i],
                 abs(high[i] - close[i - 1]),
                 abs(low[i] - close[i - 1]))
        atr[i] = atr[i - 1] * (atr_window - 1) / atr_window + tr / atr_window
    return macd, signal, atr
@jit(nopython=True)
def generate_trading_signals(close, open_, high, low, volume,
                             macd, signal, atr,
                             threshold, volatility_window):
    """Flag the bars on which every entry condition holds.

    Args:
        close, open_, high, low, volume: 1-D price/volume arrays, oldest first.
        macd, signal, atr: Indicator arrays aligned with the price arrays.
        threshold: Upper bound for lowest-low / highest-high over the
            volatility window.
        volatility_window: Number of bars (current bar included) in that window.

    Returns:
        Boolean array of ``len(close)``; the first three bars are never flagged.
    """
    bar_count = len(close)
    flags = np.zeros(bar_count, dtype=np.bool_)

    for k in range(3, bar_count):
        prev_close = close[k - 1]
        # Bars with a non-positive open or previous close are skipped outright.
        if open_[k] <= 0 or prev_close <= 0:
            continue

        bar_open = open_[k]
        bar_close = close[k]
        bar_high = high[k]
        bar_low = low[k]

        # --- candle shape ---
        # Up bar whose high exceeds 1.005x the previous close.
        rising = bar_close > bar_open
        breakout = rising and (bar_high / prev_close > 1.005)
        # Body larger than both shadows.
        body = abs(bar_close - bar_open)
        shadow_up = bar_high - max(bar_close, bar_open)
        shadow_down = min(bar_close, bar_open) - bar_low
        solid_body = (body > shadow_up) and (body > shadow_down)
        # High/low ratio below 1.12 and high/open ratio above 1.036
        # (open is known to be usable here because of the guard above).
        range_ok = (bar_high / bar_low < 1.12) and (bar_high / bar_open > 1.036)
        # Close below 1.10x the previous close (limit-up bars excluded).
        below_limit = bar_close < prev_close * 1.10

        # --- indicators ---
        # ATR above 80% of its mean over the last (up to) five bars.
        atr_ok = atr[k] > np.mean(atr[max(0, k - 4):k + 1]) * 0.8
        # MACD histogram more than 1.2x the previous bar's histogram.
        momentum_ok = (macd[k] - signal[k]) > (macd[k - 1] - signal[k - 1]) * 1.2

        # --- range compression over the volatility window ---
        first = max(0, k - volatility_window + 1)
        highest = np.max(high[first:k + 1])
        if highest > 0:
            lowest = np.min(low[first:k + 1])
            spread_ok = (lowest / highest) < threshold
        else:
            spread_ok = False

        # --- volume ---
        # Below the mean of the preceding (up to) ten bars.
        quiet = volume[k] < np.mean(volume[max(0, k - 10):k])
        # Below 3.5x the minimum of bars [k-20, k-2].
        earlier = volume[max(0, k - 20):k - 1]
        if len(earlier) == 0:
            near_floor = False
        else:
            near_floor = volume[k] < np.min(earlier) * 3.5

        shape_ok = breakout and solid_body and range_ok and below_limit
        indicator_ok = atr_ok and momentum_ok and spread_ok
        volume_ok = quiet and near_floor
        flags[k] = shape_ok and indicator_ok and volume_ok
    return flags
# ========== 市场状态判断 ==========
def get_market_condition(index_data):
    """Classify the market regime from the latest 20- and 60-bar close means.

    Args:
        index_data: DataFrame with a ``close`` column, oldest row first.

    Returns:
        ``'bull'`` when MA20 is more than 5% above MA60, ``'bear'`` when it is
        more than 5% below, otherwise ``'neutral'``. Fewer than 60 rows, or a
        NaN moving average, also yields ``'neutral'``.
    """
    if len(index_data) < 60:
        return 'neutral'

    closes = index_data['close']
    fast_ma = closes.rolling(20).mean().iloc[-1]
    slow_ma = closes.rolling(60).mean().iloc[-1]

    if pd.isna(fast_ma) or pd.isna(slow_ma):
        return 'neutral'
    if fast_ma > slow_ma * 1.05:
        return 'bull'
    if fast_ma < slow_ma * 0.95:
        return 'bear'
    return 'neutral'
# ========== 数据加载处理 ==========
def load_index_data(index_path):
    """Load the tab-separated index file and normalise its date column.

    The date column may be named ``date`` or ``trade_date`` in the file; it is
    always returned as ``trade_date`` because the rest of the script looks the
    column up under that name. Dates are parsed as ``YYYYMMDD``.

    Args:
        index_path: Path to the index file. It must provide the columns
            ``open``, ``high``, ``low``, ``close`` and ``volume``.

    Returns:
        DataFrame sorted by ``trade_date`` and restricted to 2022-01-01
        onwards, or ``None`` if the file cannot be read or parsed (the error
        is logged).
    """
    try:
        # Read only the header to find out which date column name is in use.
        header = pd.read_csv(index_path, sep='\t', nrows=0)
        date_col = 'date' if 'date' in header.columns else 'trade_date'
        index_data = pd.read_csv(
            index_path,
            sep='\t',
            usecols=[date_col, 'open', 'high', 'low', 'close', 'volume'],
            # Keep the raw YYYYMMDD text so it can be parsed with an explicit
            # format (read_csv's date_parser argument is deprecated).
            dtype={date_col: str},
        )
        index_data[date_col] = pd.to_datetime(index_data[date_col], format='%Y%m%d')
        # Without this rename a file using 'date' failed on every later
        # 'trade_date' lookup, starting with the log line below.
        index_data = index_data.rename(columns={date_col: 'trade_date'})
        index_data = index_data.sort_values('trade_date')
        # Keep 2022 and later only.
        index_data = index_data[index_data['trade_date'] >= pd.Timestamp('2022-01-01')]
        logging.info(
            f"指数数据加载成功，时间范围: {index_data['trade_date'].min().date()} ~ {index_data['trade_date'].max().date()}")
        return index_data
    except Exception as e:
        logging.error(f"指数数据加载失败: {str(e)}")
        return None
def process_stock_file(file_path, index_data):
    """Load one stock file, compute indicators and attach the signal column.

    Args:
        file_path: Tab-separated file with the columns ``trade_date``
            (YYYYMMDD), ``open``, ``high``, ``low``, ``close`` and ``vol``.
        index_data: Index DataFrame with ``trade_date`` and ``close`` columns;
            it bounds the date range and selects the volatility threshold.

    Returns:
        ``(stock_code, df)`` where ``stock_code`` is the part of the file name
        before the first underscore and ``df`` is sorted by date, carries a
        fresh 0..n-1 index and a boolean ``signal`` column. ``None`` when no
        usable rows remain or processing fails (the error is logged).
    """
    try:
        required_cols = ['trade_date', 'open', 'high', 'low', 'close', 'vol']
        df = pd.read_csv(file_path, sep='\t', usecols=required_cols)

        # Drop rows with non-positive prices/volume or an inconsistent
        # high/low/close relationship.
        valid = (
            (df['open'] > 0) &
            (df['close'] > 0) &
            (df['high'] > 0) &
            (df['low'] > 0) &
            (df['high'] >= df['low']) &
            (df['close'] >= df['low']) &
            (df['close'] <= df['high']) &
            (df['vol'] > 0)
        )
        df = df[valid]
        if df.empty:
            logging.warning(f"文件 {os.path.basename(file_path)} 无有效数据")
            return None

        df = df.rename(columns={'vol': 'volume'})
        df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d', errors='coerce')
        df = df.dropna(subset=['trade_date']).sort_values('trade_date')

        # Align with the index's date range, then keep 2022 and later.
        start_date = index_data['trade_date'].min()
        end_date = index_data['trade_date'].max()
        df = df[(df['trade_date'] >= start_date) & (df['trade_date'] <= end_date)]
        df = df[df['trade_date'] >= pd.Timestamp('2022-01-01')]
        if df.empty:
            logging.warning(f"文件 {os.path.basename(file_path)} 无有效数据")
            return None

        # Filtering and sorting leave the original row labels in place;
        # renumber them so label i is also position i for any consumer that
        # mixes .index values with positional access.
        df = df.reset_index(drop=True)

        close = df['close'].values.astype(np.float64)
        open_ = df['open'].values.astype(np.float64)
        high = df['high'].values.astype(np.float64)
        low = df['low'].values.astype(np.float64)
        volume = df['volume'].values.astype(np.float64)
        macd, signal, atr = calculate_technical_indicators(
            close, high, low, volume,
            atr_window=ATR_WINDOW,
            volatility_window=VOLATILITY_WINDOW
        )

        # NOTE(review): the regime is derived from the whole index history,
        # so one threshold (based on the latest index rows) is applied to
        # every bar of this stock — confirm this is intended.
        market_condition = get_market_condition(index_data)
        threshold_map = {
            'bull': BULL_THRESHOLD,
            'bear': BEAR_THRESHOLD,
            'neutral': NEUTRAL_THRESHOLD
        }
        threshold = threshold_map.get(market_condition, NEUTRAL_THRESHOLD)

        signals = generate_trading_signals(
            close, open_, high, low, volume,
            macd, signal, atr,
            threshold=threshold,
            volatility_window=VOLATILITY_WINDOW
        )
        df['signal'] = signals

        stock_code = os.path.basename(file_path).split('_')[0]
        return stock_code, df
    except Exception as e:
        logging.error(f"处理文件 {os.path.basename(file_path)} 失败: {str(e)}")
        return None
# ========== 回测分析模块 ==========
def backtest_strategy(all_data, index_data):
    """Backtest every signal with a regime-dependent holding period.

    For each signalled bar the entry price is that bar's close. The market
    regime is evaluated on the index rows dated on or before the signal date
    and mapped to a holding period through ``HOLDING_DAYS_MAP`` (default 2).
    Returns are measured on the closes of the ``holding_days + 1`` bars that
    follow the signal bar, truncated at the end of the data.

    Args:
        all_data: Mapping of stock code to a date-sorted DataFrame with the
            columns ``trade_date``, ``open``, ``close`` and a boolean
            ``signal``. Entries that are ``None`` or lack ``signal`` are
            skipped.
        index_data: Index DataFrame with ``trade_date`` and ``close`` columns.

    Returns:
        DataFrame with one row per trade and the columns ``code``, ``date``,
        ``market``, ``holding_days``, ``return``, ``max_profit`` and
        ``max_loss`` (ratios rounded to 4 decimals). Empty when there are no
        trades.
    """
    results = []
    # The regime depends only on the date, so evaluate it once per date
    # instead of once per (stock, signal) pair.
    regime_by_date = {}

    for stock_code, data in all_data.items():
        if data is None or 'signal' not in data.columns:
            continue
        try:
            row_count = len(data)
            closes = data['close']
            opens = data['open']
            dates = data['trade_date']
            # Work with row positions: the frame's index labels are not
            # guaranteed to equal positions after filtering and sorting, so
            # feeding labels to .iloc would read the wrong rows.
            signal_positions = np.flatnonzero(data['signal'].to_numpy(dtype=bool))

            for pos in signal_positions:
                # A following bar is required to measure any return.
                if pos + 1 >= row_count:
                    continue
                entry_price = closes.iat[pos]
                # The entry-price check also rules out a zero divisor below.
                if entry_price <= 0 or opens.iat[pos + 1] <= 0:
                    continue

                current_date = dates.iat[pos]
                market_condition = regime_by_date.get(current_date)
                if market_condition is None:
                    historical_index = index_data[index_data['trade_date'] <= current_date]
                    market_condition = get_market_condition(historical_index)
                    regime_by_date[current_date] = market_condition
                holding_days = HOLDING_DAYS_MAP.get(market_condition, 2)

                # Closes of the bars after the signal bar; positional slicing
                # stops at the end of the data on its own.
                exit_prices = closes.iloc[pos + 1:pos + holding_days + 2]

                max_profit = (exit_prices.max() - entry_price) / entry_price
                max_loss = (exit_prices.min() - entry_price) / entry_price
                final_return = (exit_prices.iloc[-1] - entry_price) / entry_price

                results.append({
                    'code': stock_code,
                    'date': current_date.strftime('%Y-%m-%d'),
                    'market': market_condition,
                    'holding_days': holding_days,
                    'return': round(final_return, 4),
                    'max_profit': round(max_profit, 4),
                    'max_loss': round(max_loss, 4)
                })
        except Exception as e:
            # One bad stock must not abort the whole run.
            logging.error(f"处理股票 {stock_code} 时出错: {str(e)}")
    return pd.DataFrame(results)
def analyze_results(results_df):
    """Print summary statistics for the trades and show two diagnostic plots.

    Args:
        results_df: Trade table with a ``return`` and a ``holding_days``
            column; a ``market`` column is optional and enables the per-regime
            breakdown.

    Returns:
        None. An empty table only logs a warning.
    """
    if results_df.empty:
        logging.warning("无有效交易记录")
        return

    returns = results_df['return']
    has_market = 'market' in results_df.columns

    # Headline figures.
    total_trades = len(results_df)
    annual_return = returns.mean() * 252
    win_rate = len(results_df[returns > 0]) / total_trades
    avg_gain = returns[returns > 0].mean()
    avg_loss = returns[returns < 0].mean()
    profit_factor = avg_gain / abs(avg_loss)

    print("\n策略表现汇总:")
    print(f"总交易次数: {total_trades}")
    print(f"年化收益率: {annual_return:.2%}")
    print(f"胜率: {win_rate:.2%}")
    print(f"盈亏比: {profit_factor:.2f}")

    # Per-regime breakdown.
    if has_market:
        per_regime = results_df.groupby('market').agg({
            'return': ['mean', 'count'],
            'holding_days': 'mean'
        })
        print("\n分市场状态表现:")
        print(per_regime)

    # Left panel: return histogram. Right panel: holding period vs return.
    plt.figure(figsize=(12, 5))

    plt.subplot(121)
    returns.hist(bins=20, alpha=0.7)
    plt.title('收益率分布')
    plt.xlabel('收益率')
    plt.ylabel('频次')

    plt.subplot(122)
    if has_market:
        for regime, trades in results_df.groupby('market'):
            plt.scatter(trades['holding_days'], trades['return'], alpha=0.5, label=regime)
        plt.legend()
    plt.axhline(0, color='red', linestyle='--')
    plt.title('持仓周期 vs 收益率')
    plt.xlabel('持仓天数')
    plt.ylabel('收益率')

    plt.tight_layout()
    plt.show()
# ========== 主程序 ==========
if __name__ == "__main__":
    # Script entry point: load the index, process every stock file in
    # parallel, run the backtest, report and save the trades.
    STOCK_DIR = 'day/'
    INDEX_PATH = 'index/000001.SH.txt'

    logging.info("正在加载指数数据...")
    index_data = load_index_data(INDEX_PATH)
    if index_data is None:
        # The loader has already logged the cause. Exit with a failure
        # status; the bare exit() helper comes from the site module and
        # reports success.
        raise SystemExit(1)

    logging.info("正在加载个股数据...")
    # NOTE(review): these are substring matches on the file name, so '688'
    # also excludes any file whose name merely contains "688", and '*ST' is
    # already covered by 'ST' — confirm against the real file naming scheme.
    excluded_keywords = ('ST', '*ST', '688')
    stock_files = [
        os.path.join(STOCK_DIR, f)
        for f in os.listdir(STOCK_DIR)
        if f.endswith('.txt') and not any(kw in f for kw in excluded_keywords)
    ]

    all_data = {}
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = {executor.submit(process_stock_file, f, index_data): f for f in stock_files}
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            # process_stock_file returns None for files it could not use.
            if result is not None:
                code, data = result
                all_data[code] = data

    if not all_data:
        logging.error("没有加载到有效股票数据")
        raise SystemExit(1)

    logging.info("开始回测...")
    results_df = backtest_strategy(all_data, index_data)

    analyze_results(results_df)

    if not results_df.empty:
        results_df.to_csv('strategy_backtest_results.csv', index=False)
        logging.info("回测结果已保存至 strategy_backtest_results.csv")