Files
strategy_backtest/optimization/param_space.py

176 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""参数空间定义与生成。
用于定义策略参数的搜索空间,支持:
- range整数范围
- list离散值列表
- 可扩展:连续分布、对数空间等
示例:
param_space = {
"short_window": range(3, 21, 2),
"long_window": range(20, 61, 5),
"max_hold_days": range(3, 11),
"stop_loss_pct": [0.03, 0.05, 0.08],
"take_profit_pct": [0.10, 0.15, 0.20],
}
combinations = generate_param_combinations(param_space)
"""
from __future__ import annotations
import itertools
from typing import Any, Dict, List, Union
from utils.logger import setup_logger
logger = setup_logger(__name__)
def generate_param_combinations(param_space: Dict[str, Union[range, List]]) -> List[Dict[str, Any]]:
"""根据参数空间生成所有参数组合(笛卡尔积)。
参数:
param_space: 参数名 -> 取值范围/列表的字典
返回:
List[Dict]: 参数组合列表,每个元素是一组参数
示例:
>>> param_space = {"a": range(1, 3), "b": [0.1, 0.2]}
>>> generate_param_combinations(param_space)
[
{"a": 1, "b": 0.1},
{"a": 1, "b": 0.2},
{"a": 2, "b": 0.1},
{"a": 2, "b": 0.2}
]
"""
if not param_space:
return [{}]
# 提取参数名和值列表
param_names = list(param_space.keys())
param_values = []
for param_name in param_names:
values = param_space[param_name]
# 如果是 range 对象,转为列表
if isinstance(values, range):
param_values.append(list(values))
elif isinstance(values, (list, tuple)):
# 检查是否是 [min, max, step] 格式
if len(values) == 3 and all(isinstance(v, (int, float)) for v in values):
min_val, max_val, step = values
# 生成范围
if isinstance(step, int) and isinstance(min_val, int):
param_values.append(list(range(int(min_val), int(max_val) + 1, int(step))))
else:
# 浮点数范围
import numpy as np
param_values.append(list(np.arange(min_val, max_val + step/2, step)))
else:
# 离散值列表
param_values.append(list(values))
else:
# 单个值也转为列表
param_values.append([values])
# 生成笛卡尔积
combinations = []
for combo in itertools.product(*param_values):
param_dict = dict(zip(param_names, combo))
combinations.append(param_dict)
logger.info(f"参数空间大小: {len(combinations)} 组参数组合")
return combinations
def validate_param_space(param_space: Dict[str, Union[range, List]]) -> bool:
"""验证参数空间定义的合法性。
检查:
- 参数空间不为空
- 所有参数至少有一个取值
- range 对象有效start < stop
返回:
bool: True 表示合法False 表示不合法
"""
if not param_space:
logger.error("参数空间为空")
return False
for param_name, values in param_space.items():
if isinstance(values, range):
if len(values) == 0:
logger.error(f"参数 {param_name} 的 range 为空: {values}")
return False
elif isinstance(values, (list, tuple)):
if len(values) == 0:
logger.error(f"参数 {param_name} 的列表为空")
return False
else:
logger.warning(f"参数 {param_name} 的值类型未知: {type(values)},将作为单个值处理")
return True
def estimate_combinations_count(param_space: Dict[str, Union[range, List]]) -> int:
"""估算参数组合总数(不实际生成)。
用于在生成前评估计算量。
返回:
int: 参数组合总数
"""
if not param_space:
return 0
total = 1
for values in param_space.values():
if isinstance(values, range):
total *= len(values)
elif isinstance(values, (list, tuple)):
# 检查是否是 [min, max, step] 格式
if len(values) == 3 and all(isinstance(v, (int, float)) for v in values):
min_val, max_val, step = values
total *= int((max_val - min_val) / step) + 1
else:
total *= len(values)
else:
total *= 1
return total
def apply_param_constraints(
combinations: List[Dict[str, Any]],
constraint_func: callable
) -> List[Dict[str, Any]]:
"""应用参数约束,过滤无效的参数组合。
参数:
combinations: 参数组合列表
constraint_func: 约束函数输入params字典返回True/False
返回:
过滤后的参数组合列表
示例:
>>> combinations = [{"a": 1, "b": 2}, {"a": 3, "b": 2}]
>>> constraint = lambda p: p["a"] < p["b"]
>>> apply_param_constraints(combinations, constraint)
[{"a": 1, "b": 2}]
"""
if constraint_func is None:
return combinations
original_count = len(combinations)
filtered = [combo for combo in combinations if constraint_func(combo)]
filtered_count = original_count - len(filtered)
if filtered_count > 0:
logger.info(f"约束条件过滤掉 {filtered_count} 组参数,剩余 {len(filtered)}")
return filtered