新建回测系统,并提交
This commit is contained in:
175
optimization/param_space.py
Normal file
175
optimization/param_space.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""参数空间定义与生成。
|
||||
|
||||
用于定义策略参数的搜索空间,支持:
|
||||
- range:整数范围
|
||||
- list:离散值列表
|
||||
- 可扩展:连续分布、对数空间等
|
||||
|
||||
示例:
|
||||
param_space = {
|
||||
"short_window": range(3, 21, 2),
|
||||
"long_window": range(20, 61, 5),
|
||||
"max_hold_days": range(3, 11),
|
||||
"stop_loss_pct": [0.03, 0.05, 0.08],
|
||||
"take_profit_pct": [0.10, 0.15, 0.20],
|
||||
}
|
||||
|
||||
combinations = generate_param_combinations(param_space)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def generate_param_combinations(param_space: Dict[str, Union[range, List]]) -> List[Dict[str, Any]]:
|
||||
"""根据参数空间生成所有参数组合(笛卡尔积)。
|
||||
|
||||
参数:
|
||||
param_space: 参数名 -> 取值范围/列表的字典
|
||||
|
||||
返回:
|
||||
List[Dict]: 参数组合列表,每个元素是一组参数
|
||||
|
||||
示例:
|
||||
>>> param_space = {"a": range(1, 3), "b": [0.1, 0.2]}
|
||||
>>> generate_param_combinations(param_space)
|
||||
[
|
||||
{"a": 1, "b": 0.1},
|
||||
{"a": 1, "b": 0.2},
|
||||
{"a": 2, "b": 0.1},
|
||||
{"a": 2, "b": 0.2}
|
||||
]
|
||||
"""
|
||||
if not param_space:
|
||||
return [{}]
|
||||
|
||||
# 提取参数名和值列表
|
||||
param_names = list(param_space.keys())
|
||||
param_values = []
|
||||
|
||||
for param_name in param_names:
|
||||
values = param_space[param_name]
|
||||
# 如果是 range 对象,转为列表
|
||||
if isinstance(values, range):
|
||||
param_values.append(list(values))
|
||||
elif isinstance(values, (list, tuple)):
|
||||
# 检查是否是 [min, max, step] 格式
|
||||
if len(values) == 3 and all(isinstance(v, (int, float)) for v in values):
|
||||
min_val, max_val, step = values
|
||||
# 生成范围
|
||||
if isinstance(step, int) and isinstance(min_val, int):
|
||||
param_values.append(list(range(int(min_val), int(max_val) + 1, int(step))))
|
||||
else:
|
||||
# 浮点数范围
|
||||
import numpy as np
|
||||
param_values.append(list(np.arange(min_val, max_val + step/2, step)))
|
||||
else:
|
||||
# 离散值列表
|
||||
param_values.append(list(values))
|
||||
else:
|
||||
# 单个值也转为列表
|
||||
param_values.append([values])
|
||||
|
||||
# 生成笛卡尔积
|
||||
combinations = []
|
||||
for combo in itertools.product(*param_values):
|
||||
param_dict = dict(zip(param_names, combo))
|
||||
combinations.append(param_dict)
|
||||
|
||||
logger.info(f"参数空间大小: {len(combinations)} 组参数组合")
|
||||
return combinations
|
||||
|
||||
|
||||
def validate_param_space(param_space: Dict[str, Union[range, List]]) -> bool:
|
||||
"""验证参数空间定义的合法性。
|
||||
|
||||
检查:
|
||||
- 参数空间不为空
|
||||
- 所有参数至少有一个取值
|
||||
- range 对象有效(start < stop)
|
||||
|
||||
返回:
|
||||
bool: True 表示合法,False 表示不合法
|
||||
"""
|
||||
if not param_space:
|
||||
logger.error("参数空间为空")
|
||||
return False
|
||||
|
||||
for param_name, values in param_space.items():
|
||||
if isinstance(values, range):
|
||||
if len(values) == 0:
|
||||
logger.error(f"参数 {param_name} 的 range 为空: {values}")
|
||||
return False
|
||||
elif isinstance(values, (list, tuple)):
|
||||
if len(values) == 0:
|
||||
logger.error(f"参数 {param_name} 的列表为空")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"参数 {param_name} 的值类型未知: {type(values)},将作为单个值处理")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def estimate_combinations_count(param_space: Dict[str, Union[range, List]]) -> int:
|
||||
"""估算参数组合总数(不实际生成)。
|
||||
|
||||
用于在生成前评估计算量。
|
||||
|
||||
返回:
|
||||
int: 参数组合总数
|
||||
"""
|
||||
if not param_space:
|
||||
return 0
|
||||
|
||||
total = 1
|
||||
for values in param_space.values():
|
||||
if isinstance(values, range):
|
||||
total *= len(values)
|
||||
elif isinstance(values, (list, tuple)):
|
||||
# 检查是否是 [min, max, step] 格式
|
||||
if len(values) == 3 and all(isinstance(v, (int, float)) for v in values):
|
||||
min_val, max_val, step = values
|
||||
total *= int((max_val - min_val) / step) + 1
|
||||
else:
|
||||
total *= len(values)
|
||||
else:
|
||||
total *= 1
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def apply_param_constraints(
|
||||
combinations: List[Dict[str, Any]],
|
||||
constraint_func: callable
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""应用参数约束,过滤无效的参数组合。
|
||||
|
||||
参数:
|
||||
combinations: 参数组合列表
|
||||
constraint_func: 约束函数,输入params字典,返回True/False
|
||||
|
||||
返回:
|
||||
过滤后的参数组合列表
|
||||
|
||||
示例:
|
||||
>>> combinations = [{"a": 1, "b": 2}, {"a": 3, "b": 2}]
|
||||
>>> constraint = lambda p: p["a"] < p["b"]
|
||||
>>> apply_param_constraints(combinations, constraint)
|
||||
[{"a": 1, "b": 2}]
|
||||
"""
|
||||
if constraint_func is None:
|
||||
return combinations
|
||||
|
||||
original_count = len(combinations)
|
||||
filtered = [combo for combo in combinations if constraint_func(combo)]
|
||||
filtered_count = original_count - len(filtered)
|
||||
|
||||
if filtered_count > 0:
|
||||
logger.info(f"约束条件过滤掉 {filtered_count} 组参数,剩余 {len(filtered)} 组")
|
||||
|
||||
return filtered
|
||||
Reference in New Issue
Block a user