475 lines
14 KiB
Python
475 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
通用工具函数模块
|
||
提供各种常用的工具函数
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime, timedelta
|
||
import re
|
||
from typing import Any, Dict, List, Union, Optional, Tuple
|
||
|
||
|
||
def format_date(date: Union[datetime, str], format_str: str = '%Y-%m-%d') -> str:
    """
    Format a date value as a string.

    Args:
        date: A ``datetime``/``date`` object (``pandas.Timestamp`` included,
            being a ``datetime`` subclass) or a date string.
        format_str: ``strftime`` format string for the output.

    Returns:
        str: The formatted date. A string that cannot be parsed or formatted
        is returned unchanged; values of any other type are passed to ``str()``.
    """
    # Local alias: the parameter is named ``date``, shadowing the date class.
    import datetime as _dt

    if isinstance(date, str):
        try:
            return parse_date(date).strftime(format_str)
        except Exception:
            # Best-effort: leave unparseable strings untouched.
            return date

    # ``datetime`` is a subclass of ``date``, so this covers both. Previously a
    # plain ``date`` fell through to str() and ignored ``format_str``.
    if isinstance(date, _dt.date):
        return date.strftime(format_str)

    return str(date)
|
||
|
||
|
||
def parse_date(date_str: str, formats: Optional[List[str]] = None) -> datetime:
    """
    Parse a date string into a ``datetime``.

    Each format in ``formats`` is tried in order with ``datetime.strptime``.
    If none matches, ``pandas.to_datetime`` is used as a lenient fallback.

    Args:
        date_str: The date string to parse.
        formats: ``strptime`` formats to try, in order. Defaults to the
            ``%Y-%m-%d`` / ``%Y%m%d`` / ``%Y/%m/%d`` layouts, each optionally
            followed by ``%H:%M:%S`` or ``%H:%M``.

    Returns:
        datetime: The parsed date.

    Raises:
        ValueError: If the string cannot be parsed into a real date. This
            includes strings that pandas maps to ``NaT`` (e.g. ``'NaT'``).
    """
    if formats is None:
        # Common date formats.
        formats = [
            '%Y-%m-%d',
            '%Y%m%d',
            '%Y/%m/%d',
            '%Y-%m-%d %H:%M:%S',
            '%Y%m%d %H:%M:%S',
            '%Y/%m/%d %H:%M:%S',
            '%Y-%m-%d %H:%M',
            '%Y%m%d %H:%M',
            '%Y/%m/%d %H:%M'
        ]

    for fmt in formats:
        try:
            return datetime.strptime(date_str, fmt)
        except ValueError:
            continue

    # All explicit formats failed: fall back to pandas' lenient parser.
    try:
        parsed = pd.to_datetime(date_str)
        # pandas returns NaT instead of raising for missing-value strings.
        # NaT is not a usable datetime, so treat it as a parse failure below.
        if not pd.isna(parsed):
            return parsed.to_pydatetime()
    except Exception as e:
        raise ValueError(f'无法解析日期字符串: {date_str}') from e
    raise ValueError(f'无法解析日期字符串: {date_str}')
|
||
|
||
|
||
def ensure_directory(directory: str) -> None:
    """
    Make sure ``directory`` exists, creating it and any missing parents.

    Args:
        directory: Path of the directory.

    Raises:
        IOError: If the directory cannot be created, including when the
            path already exists but is not a directory (e.g. a regular file).
    """
    try:
        # exist_ok=True makes this idempotent without a racy exists() check.
        # It still raises when the path exists as a non-directory, a case
        # the old os.path.exists() pre-check let through silently.
        os.makedirs(directory, exist_ok=True)
    except Exception as e:
        raise IOError(f'创建目录失败: {directory}. 错误: {str(e)}') from e
|
||
|
||
|
||
def load_json(file_path: str) -> Dict[str, Any]:
    """
    Load and decode a UTF-8 JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        dict: The decoded document. Whatever top-level value the file holds
        is returned as decoded by ``json.load`` (e.g. a list for a JSON array).

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the content is not valid JSON.
        IOError: For any other failure while reading the file.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f'文件不存在: {file_path}')

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        # Chain the original error so its position info stays in the traceback.
        raise ValueError(f'JSON解析错误: {str(e)}') from e
    except Exception as e:
        raise IOError(f'读取文件失败: {str(e)}') from e
|
||
|
||
|
||
def save_json(data: Any, file_path: str, ensure_dir: bool = True, indent: int = 2) -> None:
    """
    Serialize ``data`` as UTF-8 JSON and write it to ``file_path``.

    The data is fully serialized before the file is opened. A value that
    cannot be encoded therefore raises without truncating or partially
    overwriting an existing file.

    Args:
        data: The data to save.
        file_path: Destination JSON file path.
        ensure_dir: If True, create the parent directory when it is missing.
        indent: Number of spaces used for indentation.

    Raises:
        IOError: If serialization fails or the file cannot be written
            (directory-creation failures propagate from ``ensure_directory``).
    """
    if ensure_dir:
        # Make sure the parent directory exists.
        dir_path = os.path.dirname(file_path)
        if dir_path:
            ensure_directory(dir_path)

    try:
        # Serialize first: json.dump() streams into the already-truncated file,
        # so a TypeError midway used to leave a corrupt, partial file behind.
        text = json.dumps(data, ensure_ascii=False, indent=indent)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(text)
    except Exception as e:
        raise IOError(f'保存文件失败: {str(e)}') from e
|
||
|
||
|
||
def _window_from_name(indicator: str, bare_name: str, default: int) -> int:
    """Return the window embedded in an indicator name (20 for 'MA20'), else ``default``.

    ``default`` is used when the name equals ``bare_name`` (case-insensitive)
    or contains no digits.
    """
    if indicator.lower() != bare_name:
        match = re.search(r'\d+', indicator)
        if match:
            return int(match.group())
    return default


def calculate_indicators(
    df: pd.DataFrame,
    indicators: List[str],
    params: Optional[Dict[str, Any]] = None
) -> pd.DataFrame:
    """
    Compute common technical indicators and append them as new columns.

    Args:
        df: OHLCV DataFrame. Columns read: 'close' (MA, EMA, MACD, RSI,
            Bollinger, KDJ), 'high'/'low' (KDJ) and 'volume' (volume MA).
        indicators: Case-insensitive indicator names:
            'macd'; 'ma' or any other name starting with 'ma' (e.g. 'MA20');
            'ema'/'emaN'; 'rsi'/'rsiN'; 'kdj' or 'stochastic';
            'bollinger' or 'bbands'; 'volume_ma' or 'vmaN'.
            The first run of digits in a name overrides the default window.
            Unrecognised names are silently ignored.
        params: Optional overrides: ma_window, ema_window, rsi_window,
            macd_fast, macd_slow, macd_signal, kdj_n, kdj_m1, kdj_m2,
            bollinger_window, bollinger_std, volume_ma_window.

    Returns:
        pd.DataFrame: A copy of ``df`` with the indicator columns added
        (MA{n}, EMA{n}, MACD/MACD_Signal/MACD_Hist, RSI{n}, K/D/J,
        BB_Middle/BB_Upper/BB_Lower, VMA{n}). ``df`` itself is not modified.
    """
    result_df = df.copy()
    params = params or {}

    for indicator in indicators:
        name = indicator.lower()

        # 'macd' must be checked before the generic 'ma' prefix: it starts with
        # 'ma', so the previous ordering routed it to the moving-average branch
        # (producing MA5) and the MACD branch was unreachable.
        if name == 'macd':
            fast_period = params.get('macd_fast', 12)
            slow_period = params.get('macd_slow', 26)
            signal_period = params.get('macd_signal', 9)

            close = result_df['close']
            ema_fast = close.ewm(span=fast_period, adjust=False).mean()
            ema_slow = close.ewm(span=slow_period, adjust=False).mean()
            result_df['MACD'] = ema_fast - ema_slow
            result_df['MACD_Signal'] = result_df['MACD'].ewm(span=signal_period, adjust=False).mean()
            result_df['MACD_Hist'] = result_df['MACD'] - result_df['MACD_Signal']

        elif name.startswith('ma'):
            # Simple moving average of the close price.
            window = _window_from_name(indicator, 'ma', params.get('ma_window', 5))
            result_df[f'MA{window}'] = result_df['close'].rolling(window=window).mean()

        elif name.startswith('ema'):
            # Exponential moving average of the close price.
            window = _window_from_name(indicator, 'ema', params.get('ema_window', 5))
            result_df[f'EMA{window}'] = result_df['close'].ewm(span=window, adjust=False).mean()

        elif name.startswith('rsi'):
            # RSI built from simple rolling means of gains and losses.
            window = _window_from_name(indicator, 'rsi', params.get('rsi_window', 14))
            delta = result_df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=window).mean()
            rs = gain / loss
            result_df[f'RSI{window}'] = 100 - (100 / (1 + rs))

        elif name in ('kdj', 'stochastic'):
            n = params.get('kdj_n', 9)
            m1 = params.get('kdj_m1', 3)
            m2 = params.get('kdj_m2', 3)

            # RSV: where the close sits within the n-period low/high range, in percent.
            low_n = result_df['low'].rolling(window=n).min()
            high_n = result_df['high'].rolling(window=n).max()
            rsv = (result_df['close'] - low_n) / (high_n - low_n) * 100

            # com = m - 1 corresponds to a smoothing factor alpha = 1 / m.
            result_df['K'] = rsv.ewm(com=m1 - 1, adjust=False).mean()
            result_df['D'] = result_df['K'].ewm(com=m2 - 1, adjust=False).mean()
            result_df['J'] = 3 * result_df['K'] - 2 * result_df['D']

        elif name in ('bollinger', 'bbands'):
            window = params.get('bollinger_window', 20)
            std_dev = params.get('bollinger_std', 2)

            rolling_close = result_df['close'].rolling(window=window)
            middle = rolling_close.mean()
            # Rolling std is computed once and shared by both bands.
            band = std_dev * rolling_close.std()
            result_df['BB_Middle'] = middle
            result_df['BB_Upper'] = middle + band
            result_df['BB_Lower'] = middle - band

        elif name == 'volume_ma' or name.startswith('vma'):
            # Moving average of traded volume.
            window = _window_from_name(indicator, 'volume_ma', params.get('volume_ma_window', 5))
            result_df[f'VMA{window}'] = result_df['volume'].rolling(window=window).mean()

    return result_df
|
||
|
||
|
||
def _type_name(expected_type: Any) -> str:
    """Readable name for a 'type' rule: a class, a tuple of classes, or anything else."""
    if isinstance(expected_type, tuple):
        return '/'.join(_type_name(t) for t in expected_type)
    return getattr(expected_type, '__name__', str(expected_type))


def validate_parameters(params: Dict[str, Any], schema: Dict[str, Dict[str, Any]]) -> Tuple[bool, List[str]]:
    """
    Validate ``params`` against a rule ``schema``.

    Supported rule keys per parameter:
        'required'       -- a truthy value makes absent/None values an error.
        'type'           -- anything ``isinstance`` accepts, including a tuple of types.
        'min' / 'max'    -- bounds, checked for int/float values.
        'min_length' / 'max_length' / 'pattern' -- checked for str values.
        'allowed_values' -- the value must be a member of this collection.
    A parameter that is absent or None skips all rules except 'required'.

    Args:
        params: Parameter dict to validate.
        schema: Rules in the form ``{param_name: {rule_type: rule_value, ...}}``.

    Returns:
        tuple: (True when no rule failed, list of error messages in schema order).
    """
    errors = []

    for param_name, rules in schema.items():
        if param_name not in params or params[param_name] is None:
            # Missing values only fail when the parameter is required.
            if 'required' in rules and rules['required']:
                errors.append(f'参数 {param_name} 是必填的')
            continue

        param_value = params[param_name]

        # Type check
        if 'type' in rules:
            expected_type = rules['type']
            if not isinstance(param_value, expected_type):
                # _type_name() also handles tuples such as (int, float); the old
                # expected_type.__name__ raised AttributeError for them.
                errors.append(f'参数 {param_name} 应为 {_type_name(expected_type)} 类型,实际为 {type(param_value).__name__}')

        # Numeric range (bool counts as int here, being a subclass of it)
        if isinstance(param_value, (int, float)):
            if 'min' in rules and param_value < rules['min']:
                errors.append(f"参数 {param_name} 最小值为 {rules['min']}")
            if 'max' in rules and param_value > rules['max']:
                errors.append(f"参数 {param_name} 最大值为 {rules['max']}")

        # String length and pattern
        if isinstance(param_value, str):
            if 'min_length' in rules and len(param_value) < rules['min_length']:
                errors.append(f"参数 {param_name} 长度至少为 {rules['min_length']}")
            if 'max_length' in rules and len(param_value) > rules['max_length']:
                errors.append(f"参数 {param_name} 长度最多为 {rules['max_length']}")
            # NOTE: re.match anchors only at the start, so a pattern without a
            # trailing '$' also accepts values with extra trailing characters.
            if 'pattern' in rules and not re.match(rules['pattern'], param_value):
                errors.append(f"参数 {param_name} 不符合正则表达式规则: {rules['pattern']}")

        # Membership in the allowed values
        if 'allowed_values' in rules and param_value not in rules['allowed_values']:
            errors.append(f"参数 {param_name} 必须是以下值之一: {rules['allowed_values']}")

    return len(errors) == 0, errors
|
||
|
||
|
||
def format_number(value: Union[int, float], precision: int = 2, use_comma: bool = True) -> str:
    """
    Format a number for display.

    Args:
        value: The number to format. ``None`` yields ``'0'``.
        precision: Number of decimal places for non-integer values.
        use_comma: Whether to insert thousands separators.

    Returns:
        str: The formatted number. Integers (Python or NumPy) are shown without
        decimals. If formatting fails, ``str(value)`` is returned.
    """
    if value is None:
        return '0'

    try:
        # np.integer is not a subclass of int, so NumPy integers (e.g. taken
        # from a DataFrame) previously got float formatting such as '1,000.00'.
        if isinstance(value, (int, np.integer)):
            return f'{int(value):,}' if use_comma else str(value)

        # Floating-point formatting
        format_str = f'{{:,.{precision}f}}' if use_comma else f'{{:.{precision}f}}'
        return format_str.format(value)
    except Exception:
        return str(value)
|
||
|
||
|
||
def calculate_drawdown(portfolio_values: Union[List[float], pd.Series]) -> pd.Series:
    """
    Compute the drawdown of each point relative to its running peak.

    Args:
        portfolio_values: Portfolio value sequence. Any array-like works; a
            Series (or DataFrame) is used as-is, keeping its index.

    Returns:
        pd.Series: ``(value - running_max) / running_max`` at every point.
        NOTE(review): a running peak of 0 divides by zero (NaN/inf).
    """
    # Lists, tuples, NumPy arrays, ... lack .cummax(): wrap them in a Series.
    # Previously only lists were wrapped, so a NumPy array raised AttributeError.
    if not hasattr(portfolio_values, 'cummax'):
        portfolio_values = pd.Series(portfolio_values)

    running_peak = portfolio_values.cummax()
    return (portfolio_values - running_peak) / running_peak
|
||
|
||
|
||
def calculate_sharpe_ratio(
    returns: Union[List[float], pd.Series],
    risk_free_rate: float = 0.02,
    annualized: bool = True,
    periods_per_year: int = 252,
) -> float:
    """
    Compute the Sharpe ratio of a periodic return series.

    Args:
        returns: Per-period returns (list, array-like or Series).
        risk_free_rate: Annual risk-free rate, spread evenly over
            ``periods_per_year`` periods.
        annualized: If True, scale the ratio by ``sqrt(periods_per_year)``.
        periods_per_year: Periods per year (252 trading days by default).

    Returns:
        float: The Sharpe ratio, or 0.0 when the standard deviation of the
        excess returns is zero or undefined (fewer than two observations).
    """
    # Always go through a Series so .std() is the sample estimator (ddof=1).
    # A raw NumPy array would silently use ddof=0 and give a different ratio.
    if not isinstance(returns, pd.Series):
        returns = pd.Series(returns, dtype=float)

    # Excess returns per period
    excess_returns = returns - (risk_free_rate / periods_per_year)

    mean_excess_return = excess_returns.mean()
    std_excess_return = excess_returns.std()

    # NaN std (0 or 1 observations) used to propagate as a NaN ratio.
    if pd.isna(std_excess_return) or std_excess_return == 0:
        return 0.0

    sharpe = mean_excess_return / std_excess_return

    if annualized:
        sharpe *= np.sqrt(periods_per_year)

    return float(sharpe)
|
||
|
||
|
||
def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]:
    """
    Split a list into consecutive chunks of at most ``chunk_size`` items.

    Args:
        lst: The list to split.
        chunk_size: Maximum size of each chunk; must be at least 1.

    Returns:
        list: The chunks in order. The last chunk may be shorter, and an
        empty input yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is less than 1. A negative size used to
            return ``[]`` silently, dropping the data.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size 必须为正整数: {chunk_size}')
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
||
|
||
|
||
def flatten_list(nested_list: List[List[Any]]) -> List[Any]:
    """
    Remove one level of nesting from a list of iterables.

    Args:
        nested_list: A list whose elements are themselves iterables.

    Returns:
        list: A new list with every inner element, in order.
    """
    flat: List[Any] = []
    for inner in nested_list:
        flat.extend(inner)
    return flat
|
||
|
||
|
||
def safe_divide(a: float, b: float, default: float = 0.0) -> float:
    """
    Divide ``a`` by ``b``, returning ``default`` instead of dividing by zero.

    Args:
        a: Dividend.
        b: Divisor.
        default: Value returned when ``b == 0``.

    Returns:
        float: ``a / b``, or ``default`` when the divisor is zero.
    """
    return default if b == 0 else a / b
|
||
|
||
|
||
def get_file_paths(directory: str, pattern: str = '*') -> List[str]:
    """
    List paths in ``directory`` that match a glob ``pattern``.

    Args:
        directory: Directory to search. It is matched literally, so glob
            characters such as '[' or '*' in its name are safe.
        pattern: Glob pattern for entries directly inside ``directory``.

    Returns:
        list: Matching paths, sorted. Returns ``[]`` if ``directory`` does not exist.
    """
    import glob

    if not os.path.exists(directory):
        return []

    # Escape the directory part so only ``pattern`` is interpreted by glob.
    # Unescaped, a folder like 'data[2024]' was read as a character class.
    search_pattern = os.path.join(glob.escape(directory), pattern)
    # glob order is filesystem-dependent; sort for reproducible results.
    return sorted(glob.glob(search_pattern))
|
||
|
||
|
||
def time_function(func):
    """
    Decorator that prints how long each call to ``func`` takes.

    The timing line is printed even when ``func`` raises; the exception is
    re-raised unchanged.

    Args:
        func: The function to time.

    Returns:
        function: A wrapper with the same name/docstring (via functools.wraps).
    """
    import functools
    import time

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution. datetime.now() is
        # wall-clock time and can jump (clock sync, DST, manual changes).
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            duration = time.perf_counter() - start
            print(f'{func.__name__} 执行时间: {duration:.6f} 秒')

    return wrapper
|