# -*- coding: utf-8 -*- """ 通用工具函数模块 提供各种常用的工具函数 """ import os import json import pandas as pd import numpy as np from datetime import datetime, timedelta import re from typing import Any, Dict, List, Union, Optional, Tuple def format_date(date: Union[datetime, str], format_str: str = '%Y-%m-%d') -> str: """ 格式化日期 Args: date: 日期对象或日期字符串 format_str: 格式化字符串 Returns: str: 格式化后的日期字符串 """ if isinstance(date, str): try: date_obj = parse_date(date) return date_obj.strftime(format_str) except Exception: return date elif isinstance(date, datetime): return date.strftime(format_str) return str(date) def parse_date(date_str: str, formats: List[str] = None) -> datetime: """ 解析日期字符串 Args: date_str: 日期字符串 formats: 尝试的日期格式列表 Returns: datetime: 解析后的日期对象 """ if formats is None: # 常用的日期格式列表 formats = [ '%Y-%m-%d', '%Y%m%d', '%Y/%m/%d', '%Y-%m-%d %H:%M:%S', '%Y%m%d %H:%M:%S', '%Y/%m/%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y%m%d %H:%M', '%Y/%m/%d %H:%M' ] for fmt in formats: try: return datetime.strptime(date_str, fmt) except ValueError: continue # 如果所有格式都失败,尝试使用pandas的解析 try: return pd.to_datetime(date_str).to_pydatetime() except Exception: raise ValueError(f'无法解析日期字符串: {date_str}') def ensure_directory(directory: str) -> None: """ 确保目录存在,如果不存在则创建 Args: directory: 目录路径 """ if not os.path.exists(directory): try: os.makedirs(directory, exist_ok=True) except Exception as e: raise IOError(f'创建目录失败: {directory}. 错误: {str(e)}') def load_json(file_path: str) -> Dict[str, Any]: """ 加载JSON文件 Args: file_path: JSON文件路径 Returns: dict: 解析后的JSON数据 """ if not os.path.exists(file_path): raise FileNotFoundError(f'文件不存在: {file_path}') try: with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) except json.JSONDecodeError as e: raise ValueError(f'JSON解析错误: {str(e)}') except Exception as e: raise IOError(f'读取文件失败: {str(e)}') def save_json(data: Any, file_path: str, ensure_dir: bool = True, indent: int = 2) -> None: """ 保存数据到JSON文件 Args: data: 要保存的数据 file_path: JSON文件路径 ensure_dir: 是否确保目录存在 indent: 缩进空格数 """ if ensure_dir: # 确保目录存在 dir_path = os.path.dirname(file_path) if dir_path: ensure_directory(dir_path) try: with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=indent) except Exception as e: raise IOError(f'保存文件失败: {str(e)}') def calculate_indicators( df: pd.DataFrame, indicators: List[str], params: Dict[str, Any] = None ) -> pd.DataFrame: """ 计算常用技术指标 Args: df: 包含OHLCV数据的DataFrame indicators: 要计算的指标列表 params: 指标参数字典 Returns: pd.DataFrame: 添加了指标列的DataFrame """ result_df = df.copy() params = params or {} for indicator in indicators: if indicator.lower() == 'ma' or indicator.lower().startswith('ma'): # 移动平均线 window = params.get('ma_window', 5) if indicator.lower() != 'ma': # 尝试从指标名中提取窗口大小,如MA5, MA10等 match = re.search(r'\d+', indicator) if match: window = int(match.group()) result_df[f'MA{window}'] = result_df['close'].rolling(window=window).mean() elif indicator.lower() == 'ema' or indicator.lower().startswith('ema'): # 指数移动平均线 window = params.get('ema_window', 5) if indicator.lower() != 'ema': match = re.search(r'\d+', indicator) if match: window = int(match.group()) result_df[f'EMA{window}'] = result_df['close'].ewm(span=window, adjust=False).mean() elif indicator.lower() == 'macd': # MACD fast_period = params.get('macd_fast', 12) slow_period = params.get('macd_slow', 26) signal_period = params.get('macd_signal', 9) # 计算EMA ema_fast = result_df['close'].ewm(span=fast_period, adjust=False).mean() ema_slow = result_df['close'].ewm(span=slow_period, adjust=False).mean() # 计算MACD线 result_df['MACD'] = ema_fast - ema_slow # 计算信号线 result_df['MACD_Signal'] = result_df['MACD'].ewm(span=signal_period, adjust=False).mean() # 计算柱状图 result_df['MACD_Hist'] = result_df['MACD'] - result_df['MACD_Signal'] elif indicator.lower() == 'rsi' or indicator.lower().startswith('rsi'): # RSI window = params.get('rsi_window', 14) if indicator.lower() != 'rsi': match = re.search(r'\d+', indicator) if match: window = int(match.group()) # 计算价格变化 delta = result_df['close'].diff() gain = (delta.where(delta > 0, 0)).rolling(window=window).mean() loss = (-delta.where(delta < 0, 0)).rolling(window=window).mean() # 计算RS和RSI rs = gain / loss result_df[f'RSI{window}'] = 100 - (100 / (1 + rs)) elif indicator.lower() == 'kdj' or indicator.lower() == 'stochastic': # KDJ指标 n = params.get('kdj_n', 9) m1 = params.get('kdj_m1', 3) m2 = params.get('kdj_m2', 3) # 计算RSV low_n = result_df['low'].rolling(window=n).min() high_n = result_df['high'].rolling(window=n).max() rsv = (result_df['close'] - low_n) / (high_n - low_n) * 100 # 计算K、D、J线 result_df['K'] = rsv.ewm(com=m1-1, adjust=False).mean() result_df['D'] = result_df['K'].ewm(com=m2-1, adjust=False).mean() result_df['J'] = 3 * result_df['K'] - 2 * result_df['D'] elif indicator.lower() == 'bollinger' or indicator.lower() == 'bbands': # 布林带 window = params.get('bollinger_window', 20) std_dev = params.get('bollinger_std', 2) result_df['BB_Middle'] = result_df['close'].rolling(window=window).mean() result_df['BB_Upper'] = result_df['BB_Middle'] + std_dev * result_df['close'].rolling(window=window).std() result_df['BB_Lower'] = result_df['BB_Middle'] - std_dev * result_df['close'].rolling(window=window).std() elif indicator.lower() == 'volume_ma' or indicator.lower().startswith('vma'): # 成交量移动平均线 window = params.get('volume_ma_window', 5) if indicator.lower() != 'volume_ma': match = re.search(r'\d+', indicator) if match: window = int(match.group()) result_df[f'VMA{window}'] = result_df['volume'].rolling(window=window).mean() return result_df def validate_parameters(params: Dict[str, Any], schema: Dict[str, Dict[str, Any]]) -> Tuple[bool, List[str]]: """ 验证参数是否符合要求 Args: params: 要验证的参数字典 schema: 参数验证规则,格式为 {param_name: {rule_type: rule_value, ...}} Returns: tuple: (是否验证通过, 错误信息列表) """ errors = [] for param_name, rules in schema.items(): # 检查参数是否存在且为必填 if 'required' in rules and rules['required']: if param_name not in params or params[param_name] is None: errors.append(f'参数 {param_name} 是必填的') continue # 如果参数不存在且不是必填,则跳过后续检查 if param_name not in params or params[param_name] is None: continue param_value = params[param_name] # 检查类型 if 'type' in rules: expected_type = rules['type'] if not isinstance(param_value, expected_type): errors.append(f'参数 {param_name} 应为 {expected_type.__name__} 类型,实际为 {type(param_value).__name__}') # 检查数值范围 if isinstance(param_value, (int, float)): if 'min' in rules and param_value < rules['min']: errors.append(f"参数 {param_name} 最小值为 {rules['min']}") if 'max' in rules and param_value > rules['max']: errors.append(f"参数 {param_name} 最大值为 {rules['max']}") # 检查字符串长度 if isinstance(param_value, str): if 'min_length' in rules and len(param_value) < rules['min_length']: errors.append(f"参数 {param_name} 长度至少为 {rules['min_length']}") if 'max_length' in rules and len(param_value) > rules['max_length']: errors.append(f"参数 {param_name} 长度最多为 {rules['max_length']}") if 'pattern' in rules and not re.match(rules['pattern'], param_value): errors.append(f"参数 {param_name} 不符合正则表达式规则: {rules['pattern']}") # 检查是否在允许的值列表中 if 'allowed_values' in rules and param_value not in rules['allowed_values']: errors.append(f"参数 {param_name} 必须是以下值之一: {rules['allowed_values']}") return len(errors) == 0, errors def format_number(value: Union[int, float], precision: int = 2, use_comma: bool = True) -> str: """ 格式化数字 Args: value: 要格式化的数字 precision: 小数位数 use_comma: 是否使用千位分隔符 Returns: str: 格式化后的数字字符串 """ if value is None: return '0' try: # 如果是整数,不显示小数 if isinstance(value, int): if use_comma: return f'{value:,}' else: return str(value) # 格式化浮点数 format_str = f'{{:,.{precision}f}}' if use_comma else f'{{:.{precision}f}}' return format_str.format(value) except Exception: return str(value) def calculate_drawdown(portfolio_values: Union[List[float], pd.Series]) -> pd.Series: """ 计算回撤 Args: portfolio_values: 投资组合价值序列 Returns: pd.Series: 回撤序列 """ if isinstance(portfolio_values, list): portfolio_values = pd.Series(portfolio_values) # 计算累计最大值 cumulative_max = portfolio_values.cummax() # 计算回撤 drawdown = (portfolio_values - cumulative_max) / cumulative_max return drawdown def calculate_sharpe_ratio(returns: Union[List[float], pd.Series], risk_free_rate: float = 0.02, annualized: bool = True) -> float: """ 计算夏普比率 Args: returns: 收益率序列 risk_free_rate: 无风险利率 annualized: 是否年化 Returns: float: 夏普比率 """ if isinstance(returns, list): returns = pd.Series(returns) # 计算超额收益 excess_returns = returns - (risk_free_rate / 252) # 假设一年252个交易日 # 计算夏普比率 mean_excess_return = excess_returns.mean() std_excess_return = excess_returns.std() if std_excess_return == 0: return 0 sharpe = mean_excess_return / std_excess_return # 年化 if annualized: sharpe *= np.sqrt(252) return sharpe def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]: """ 将列表切分为指定大小的子列表 Args: lst: 原始列表 chunk_size: 每个子列表的大小 Returns: list: 子列表的列表 """ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] def flatten_list(nested_list: List[List[Any]]) -> List[Any]: """ 展平嵌套列表 Args: nested_list: 嵌套列表 Returns: list: 展平后的列表 """ return [item for sublist in nested_list for item in sublist] def safe_divide(a: float, b: float, default: float = 0.0) -> float: """ 安全除法,避免除零错误 Args: a: 被除数 b: 除数 default: 当除数为零时的默认返回值 Returns: float: 除法结果或默认值 """ if b == 0: return default return a / b def get_file_paths(directory: str, pattern: str = '*') -> List[str]: """ 获取目录下所有匹配模式的文件路径 Args: directory: 目录路径 pattern: 文件匹配模式 Returns: list: 文件路径列表 """ import glob if not os.path.exists(directory): return [] search_pattern = os.path.join(directory, pattern) return glob.glob(search_pattern) def time_function(func): """ 装饰器:测量函数执行时间 Args: func: 要测量的函数 Returns: function: 包装后的函数 """ import functools @functools.wraps(func) def wrapper(*args, **kwargs): start_time = datetime.now() result = func(*args, **kwargs) end_time = datetime.now() duration = (end_time - start_time).total_seconds() print(f'{func.__name__} 执行时间: {duration:.6f} 秒') return result return wrapper