176 lines
5.7 KiB
Python
176 lines
5.7 KiB
Python
"""参数空间定义与生成。
|
||
|
||
用于定义策略参数的搜索空间,支持:
|
||
- range:整数范围
|
||
- list:离散值列表
|
||
- 可扩展:连续分布、对数空间等
|
||
|
||
示例:
|
||
param_space = {
|
||
"short_window": range(3, 21, 2),
|
||
"long_window": range(20, 61, 5),
|
||
"max_hold_days": range(3, 11),
|
||
"stop_loss_pct": [0.03, 0.05, 0.08],
|
||
"take_profit_pct": [0.10, 0.15, 0.20],
|
||
}
|
||
|
||
combinations = generate_param_combinations(param_space)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import itertools
|
||
from typing import Any, Dict, List, Union
|
||
|
||
from utils.logger import setup_logger
|
||
|
||
logger = setup_logger(__name__)
|
||
|
||
|
||
def generate_param_combinations(param_space: Dict[str, Union[range, List]]) -> List[Dict[str, Any]]:
|
||
"""根据参数空间生成所有参数组合(笛卡尔积)。
|
||
|
||
参数:
|
||
param_space: 参数名 -> 取值范围/列表的字典
|
||
|
||
返回:
|
||
List[Dict]: 参数组合列表,每个元素是一组参数
|
||
|
||
示例:
|
||
>>> param_space = {"a": range(1, 3), "b": [0.1, 0.2]}
|
||
>>> generate_param_combinations(param_space)
|
||
[
|
||
{"a": 1, "b": 0.1},
|
||
{"a": 1, "b": 0.2},
|
||
{"a": 2, "b": 0.1},
|
||
{"a": 2, "b": 0.2}
|
||
]
|
||
"""
|
||
if not param_space:
|
||
return [{}]
|
||
|
||
# 提取参数名和值列表
|
||
param_names = list(param_space.keys())
|
||
param_values = []
|
||
|
||
for param_name in param_names:
|
||
values = param_space[param_name]
|
||
# 如果是 range 对象,转为列表
|
||
if isinstance(values, range):
|
||
param_values.append(list(values))
|
||
elif isinstance(values, (list, tuple)):
|
||
# 检查是否是 [min, max, step] 格式
|
||
if len(values) == 3 and all(isinstance(v, (int, float)) for v in values):
|
||
min_val, max_val, step = values
|
||
# 生成范围
|
||
if isinstance(step, int) and isinstance(min_val, int):
|
||
param_values.append(list(range(int(min_val), int(max_val) + 1, int(step))))
|
||
else:
|
||
# 浮点数范围
|
||
import numpy as np
|
||
param_values.append(list(np.arange(min_val, max_val + step/2, step)))
|
||
else:
|
||
# 离散值列表
|
||
param_values.append(list(values))
|
||
else:
|
||
# 单个值也转为列表
|
||
param_values.append([values])
|
||
|
||
# 生成笛卡尔积
|
||
combinations = []
|
||
for combo in itertools.product(*param_values):
|
||
param_dict = dict(zip(param_names, combo))
|
||
combinations.append(param_dict)
|
||
|
||
logger.info(f"参数空间大小: {len(combinations)} 组参数组合")
|
||
return combinations
|
||
|
||
|
||
def validate_param_space(param_space: Dict[str, Union[range, List]]) -> bool:
|
||
"""验证参数空间定义的合法性。
|
||
|
||
检查:
|
||
- 参数空间不为空
|
||
- 所有参数至少有一个取值
|
||
- range 对象有效(start < stop)
|
||
|
||
返回:
|
||
bool: True 表示合法,False 表示不合法
|
||
"""
|
||
if not param_space:
|
||
logger.error("参数空间为空")
|
||
return False
|
||
|
||
for param_name, values in param_space.items():
|
||
if isinstance(values, range):
|
||
if len(values) == 0:
|
||
logger.error(f"参数 {param_name} 的 range 为空: {values}")
|
||
return False
|
||
elif isinstance(values, (list, tuple)):
|
||
if len(values) == 0:
|
||
logger.error(f"参数 {param_name} 的列表为空")
|
||
return False
|
||
else:
|
||
logger.warning(f"参数 {param_name} 的值类型未知: {type(values)},将作为单个值处理")
|
||
|
||
return True
|
||
|
||
|
||
def estimate_combinations_count(param_space: Dict[str, Union[range, List]]) -> int:
|
||
"""估算参数组合总数(不实际生成)。
|
||
|
||
用于在生成前评估计算量。
|
||
|
||
返回:
|
||
int: 参数组合总数
|
||
"""
|
||
if not param_space:
|
||
return 0
|
||
|
||
total = 1
|
||
for values in param_space.values():
|
||
if isinstance(values, range):
|
||
total *= len(values)
|
||
elif isinstance(values, (list, tuple)):
|
||
# 检查是否是 [min, max, step] 格式
|
||
if len(values) == 3 and all(isinstance(v, (int, float)) for v in values):
|
||
min_val, max_val, step = values
|
||
total *= int((max_val - min_val) / step) + 1
|
||
else:
|
||
total *= len(values)
|
||
else:
|
||
total *= 1
|
||
|
||
return total
|
||
|
||
|
||
def apply_param_constraints(
|
||
combinations: List[Dict[str, Any]],
|
||
constraint_func: callable
|
||
) -> List[Dict[str, Any]]:
|
||
"""应用参数约束,过滤无效的参数组合。
|
||
|
||
参数:
|
||
combinations: 参数组合列表
|
||
constraint_func: 约束函数,输入params字典,返回True/False
|
||
|
||
返回:
|
||
过滤后的参数组合列表
|
||
|
||
示例:
|
||
>>> combinations = [{"a": 1, "b": 2}, {"a": 3, "b": 2}]
|
||
>>> constraint = lambda p: p["a"] < p["b"]
|
||
>>> apply_param_constraints(combinations, constraint)
|
||
[{"a": 1, "b": 2}]
|
||
"""
|
||
if constraint_func is None:
|
||
return combinations
|
||
|
||
original_count = len(combinations)
|
||
filtered = [combo for combo in combinations if constraint_func(combo)]
|
||
filtered_count = original_count - len(filtered)
|
||
|
||
if filtered_count > 0:
|
||
logger.info(f"约束条件过滤掉 {filtered_count} 组参数,剩余 {len(filtered)} 组")
|
||
|
||
return filtered
|