Files
backtrader/strategy_editor.py
2026-01-17 21:21:30 +08:00

492 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
策略编辑器
用于读取、修改和新增策略,并在保存时自动更新所有需要使用该策略的地方
"""
import os
import re
import sys
import glob
import json
from datetime import datetime
class StrategyEditor:
    """Interactive editor for strategy modules in the ``strategy/`` package.

    Supports listing, reading, creating, editing, saving and deleting strategy
    classes. On save/delete it keeps the references in sync:
    ``strategy/__init__.py`` (import line and ``__all__``), ``config.py``
    (``STRATEGY_LIST``) and ``main.py`` (import line).
    """

    # Support modules inside strategy/ that are not strategies themselves.
    _RESERVED_FILES = ('__init__.py', 'base_strategy.py', 'market_filter.py')

    def __init__(self):
        self.project_root = os.path.dirname(os.path.abspath(__file__))
        self.strategy_dir = os.path.join(self.project_root, 'strategy')
        self.config_file = os.path.join(self.project_root, 'config.py')
        self.strategy_init_file = os.path.join(self.strategy_dir, '__init__.py')
        # State of the strategy currently loaded in the editor.
        self.current_strategy = None  # file name, e.g. "foo_strategy.py"
        self.current_strategy_name = None  # class name
        self.current_strategy_content = None  # full source text

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _strategy_files(self, include_reserved=False):
        """Return the ``.py`` files of the strategy directory, sorted by path.

        Reserved support modules are skipped unless *include_reserved* is true.
        """
        files = sorted(glob.glob(os.path.join(self.strategy_dir, '*.py')))
        if include_reserved:
            return files
        return [p for p in files if os.path.basename(p) not in self._RESERVED_FILES]

    @staticmethod
    def _read_text(path):
        """Return the UTF-8 text content of *path*."""
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()

    @staticmethod
    def _write_text(path, content):
        """Overwrite *path* with *content* encoded as UTF-8."""
        with open(path, 'w', encoding='utf-8') as f:
            f.write(content)

    @staticmethod
    def _defines_class(content, class_name):
        """Return True if *content* has a ``class`` statement named exactly *class_name*.

        A plain substring test would let ``MA`` match ``class MACross`` (and an
        empty name match every class), which could make ``delete_strategy``
        remove the wrong file.
        """
        if not class_name:
            return False
        pattern = r'^\s*class\s+' + re.escape(class_name) + r'\b'
        return re.search(pattern, content, re.MULTILINE) is not None

    @staticmethod
    def _has_line(content, statement):
        """Return True if some line of *content* equals *statement* (whitespace-trimmed)."""
        return any(line.strip() == statement for line in content.splitlines())

    @staticmethod
    def _parse_names(list_body):
        """Split the body of a ``[...]`` literal into bare names, dropping empty items.

        Empty items come from trailing commas (``["A", "B",]``); keeping them
        would write a bogus ``""`` entry back into the file.
        """
        names = (item.strip().strip('"\'') for item in list_body.split(','))
        return [name for name in names if name]

    @staticmethod
    def _format_names(names):
        """Render *names* as the comma-separated, double-quoted body of a list literal."""
        return ', '.join(f'"{name}"' for name in names)

    @staticmethod
    def _insert_import(content, import_statement, prefix):
        """Insert *import_statement* after the last import line starting with *prefix*.

        All other lines keep their order. A parenthesised multi-line import is
        skipped over as a unit so the new line never lands inside it. With no
        matching import, the statement is placed at the top of the file.
        """
        lines = content.split('\n')
        insert_at = 0
        i = 0
        while i < len(lines):
            line = lines[i]
            if line.startswith(prefix) and 'import' in line:
                if '(' in line and ')' not in line:
                    # Advance to the line that closes the parenthesised import.
                    while i + 1 < len(lines) and ')' not in lines[i]:
                        i += 1
                insert_at = i + 1
            i += 1
        lines.insert(insert_at, import_statement)
        return '\n'.join(lines)

    @staticmethod
    def _remove_import(content, import_statement):
        """Drop every line of *content* that equals *import_statement* (whitespace-trimmed)."""
        lines = content.split('\n')
        return '\n'.join(line for line in lines if line.strip() != import_statement)

    # ------------------------------------------------------------------
    # Public operations
    # ------------------------------------------------------------------

    def list_strategies(self):
        """Print each strategy file with the first class it defines."""
        print("\n可用策略列表:")
        print("=" * 50)
        for file_path in self._strategy_files():
            file_name = os.path.basename(file_path)
            strategy_name = file_name.replace('.py', '')
            class_match = re.search(r'class\s+(\w+)\s*\(', self._read_text(file_path))
            if class_match:
                print(f"- {class_match.group(1)} ({file_name})")
            else:
                print(f"- {strategy_name} ({file_name}) [无类定义]")

    def read_strategy(self, strategy_name):
        """Load the strategy class *strategy_name* into the editor.

        Returns True on success; prints an error and returns False when no
        strategy file defines a class with exactly that name.
        """
        for file_path in self._strategy_files():
            content = self._read_text(file_path)
            if self._defines_class(content, strategy_name):
                self.current_strategy = os.path.basename(file_path)
                self.current_strategy_name = strategy_name
                self.current_strategy_content = content
                print(f"\n已读取策略: {strategy_name} ({self.current_strategy})")
                return True
        print(f"\n错误: 未找到策略 {strategy_name}")
        return False

    def create_new_strategy(self, strategy_name):
        """Create an in-memory skeleton for a new strategy class (written on save).

        Returns False if the name is invalid, a class with that name already
        exists, or the target file ``<name lower>_strategy.py`` already exists.
        """
        if not re.match(r'^[A-Z][A-Za-z0-9_]*$', strategy_name):
            print("\n错误: 策略名必须以大写字母开头,只能包含字母、数字和下划线")
            return False
        # Reserved modules are scanned too, so e.g. "BaseStrategy" cannot be reused.
        for file_path in self._strategy_files(include_reserved=True):
            if self._defines_class(self._read_text(file_path), strategy_name):
                print(f"\n错误: 策略 {strategy_name} 已存在")
                return False
        file_name = strategy_name.lower() + '_strategy.py'
        file_path = os.path.join(self.strategy_dir, file_name)
        # Another class may already live in this file; saving would overwrite it.
        if os.path.exists(file_path):
            print(f"\n错误: 文件 {file_name} 已存在")
            return False
        template = '\n'.join([
            "import pandas as pd",
            "import numpy as np",
            "import logging",
            "from typing import Dict, List, Tuple, Optional",
            "from .base_strategy import BaseStrategy",
            "",
            "",
            f"class {strategy_name}(BaseStrategy):",
            "    def __init__(self, config: Dict):",
            "        super().__init__()",
            "        # 使用主日志记录器,确保日志能被正确记录到文件",
            "        self.logger = logging.getLogger('backtrader')",
            "",
            "        # 策略参数",
            "        # 在此处添加策略参数",
            "",
            "    def initialize(self, data):",
            "        # 初始化信号数据",
            "        self.signals = None",
            "",
            "    def generate_signals(self, data):",
            "        if data.empty:",
            "            return pd.DataFrame()",
            "",
            "        # 确保数据按日期从旧到新排序",
            "        data = data.sort_index()",
            "",
            "        signals = pd.DataFrame(index=data.index)",
            "        signals['signal'] = 0  # 0:无信号, 1:买入信号, -1:卖出信号",
            "        signals['strength'] = 0.0  # 信号强度",
            "        signals['reason'] = ''  # 信号原因",
            "",
            "        # 在此处添加信号生成逻辑",
            "",
            "        return signals",
            "",
            "    def execute_trades(self, data, signals):",
            "        if signals.empty:",
            "            return pd.DataFrame()",
            "",
            "        # 确保数据按日期从旧到新排序",
            "        data = data.sort_index()",
            "",
            "        trades = []",
            "        position = 0",
            "",
            "        for date, signal in signals.iterrows():",
            "            if signal['signal'] == 1 and position == 0:  # 买入信号且无持仓",
            "                # 在此处添加买入逻辑",
            "                pass",
            "            elif signal['signal'] == -1 and position > 0:  # 卖出信号且有持仓",
            "                # 在此处添加卖出逻辑",
            "                pass",
            "",
            "        return pd.DataFrame(trades)",
            "",
            "    def calculate_returns(self, data, trades):",
            "        if trades.empty:",
            "            return 0.0",
            "",
            "        # 在此处添加收益率计算逻辑",
            "        return 0.0",
            "",
        ])
        self.current_strategy = file_name
        self.current_strategy_name = strategy_name
        self.current_strategy_content = template
        print(f"\n已创建新策略: {strategy_name} ({file_name})")
        print("请编辑策略内容,然后保存")
        return True

    def edit_strategy(self):
        """Replace the loaded strategy's source with text typed on stdin.

        Input ends at a line reading exactly ``EOF`` or at end of stdin. Returns
        True if new content was entered, False otherwise.
        """
        if not self.current_strategy_content:
            print("\n错误: 没有加载任何策略")
            return False
        print(f"\n编辑策略: {self.current_strategy_name}")
        print("=" * 50)
        print("当前策略内容:")
        print(self.current_strategy_content)
        print("=" * 50)
        print("\n输入新的策略内容 (以 'EOF' 结束输入):")
        new_content = []
        while True:
            try:
                line = input()
            except EOFError:
                # Closed stdin (Ctrl-D / piped input) ends input like 'EOF'.
                break
            if line == 'EOF':
                break
            new_content.append(line)
        if new_content:
            self.current_strategy_content = '\n'.join(new_content) + '\n'
            print("\n策略内容已更新")
            return True
        return False

    def save_strategy(self):
        """Write the loaded strategy to disk and register it everywhere it is referenced."""
        if not self.current_strategy_content or not self.current_strategy_name:
            print("\n错误: 没有加载任何策略")
            return False
        # 1. Strategy file itself.
        file_path = os.path.join(self.strategy_dir, self.current_strategy)
        self._write_text(file_path, self.current_strategy_content)
        print(f"\n1. 已保存策略文件: {file_path}")
        # 2. strategy/__init__.py, 3. config.py, 4. main.py
        self._update_strategy_init()
        self._update_config()
        self._update_other_references()
        print(f"\n策略 {self.current_strategy_name} 已保存并更新所有引用")
        return True

    def _update_strategy_init(self):
        """Add the loaded strategy's import and ``__all__`` entry to ``strategy/__init__.py``."""
        if not os.path.exists(self.strategy_init_file):
            print(f"2. 跳过: 未找到 {self.strategy_init_file}")
            return
        content = self._read_text(self.strategy_init_file)
        module_name = os.path.splitext(self.current_strategy)[0]
        import_statement = f"from .{module_name} import {self.current_strategy_name}"
        if not self._has_line(content, import_statement):
            content = self._insert_import(content, import_statement, 'from .')
        all_match = re.search(r'__all__\s*=\s*\[(.*?)\]', content, re.DOTALL)
        if all_match:
            all_items = self._parse_names(all_match.group(1))
            if self.current_strategy_name not in all_items:
                all_items.append(self.current_strategy_name)
                replacement = f'__all__ = [{self._format_names(all_items)}]'
                content = content[:all_match.start()] + replacement + content[all_match.end():]
        else:
            # No __all__ yet: add one.
            content += f'\n\n__all__ = ["{self.current_strategy_name}"]'
        self._write_text(self.strategy_init_file, content)
        print(f"2. 已更新 {self.strategy_init_file}")

    def _update_config(self):
        """Add the loaded strategy to ``STRATEGY_LIST`` in ``config.py``, creating the list if absent."""
        if not os.path.exists(self.config_file):
            print(f"3. 跳过: 未找到 {self.config_file}")
            return
        content = self._read_text(self.config_file)
        list_match = re.search(r'STRATEGY_LIST\s*=\s*\[(.*?)\]', content, re.DOTALL)
        if list_match:
            list_items = self._parse_names(list_match.group(1))
            if self.current_strategy_name not in list_items:
                list_items.append(self.current_strategy_name)
                replacement = f'STRATEGY_LIST = [{self._format_names(list_items)}]'
                content = content[:list_match.start()] + replacement + content[list_match.end():]
        else:
            content += f'\n\n# ========== 策略配置 ==========\nSTRATEGY_LIST = ["{self.current_strategy_name}"]'
        self._write_text(self.config_file, content)
        print(f"3. 已更新 {self.config_file}")

    def _update_other_references(self):
        """Add ``from strategy.<module> import <Class>`` to ``main.py`` if it is missing."""
        main_file = os.path.join(self.project_root, 'main.py')
        if not os.path.exists(main_file):
            print(f"4. 跳过: 未找到 {main_file}")
            return
        content = self._read_text(main_file)
        module_name = os.path.splitext(self.current_strategy)[0]
        import_statement = f'from strategy.{module_name} import {self.current_strategy_name}'
        if not self._has_line(content, import_statement):
            content = self._insert_import(content, import_statement, 'from strategy.')
            self._write_text(main_file, content)
        print(f"4. 已更新 {main_file}")

    def delete_strategy(self, strategy_name):
        """Delete the file defining *strategy_name* and remove its references.

        Returns True if a matching strategy was deleted, False otherwise.
        """
        for file_path in self._strategy_files():
            if not self._defines_class(self._read_text(file_path), strategy_name):
                continue
            # Use the real module name: the file may not follow "<name>_strategy.py".
            module_name = os.path.splitext(os.path.basename(file_path))[0]
            os.remove(file_path)
            print(f"\n1. 已删除策略文件: {os.path.basename(file_path)}")
            self._remove_from_strategy_init(strategy_name, module_name)
            self._remove_from_config(strategy_name)
            self._remove_from_other_files(strategy_name, module_name)
            print(f"策略 {strategy_name} 已删除并更新所有引用")
            return True
        print(f"\n错误: 未找到策略 {strategy_name}")
        return False

    def _remove_from_strategy_init(self, strategy_name, module_name=None):
        """Remove *strategy_name*'s import and ``__all__`` entry from ``strategy/__init__.py``.

        *module_name* is the strategy's module (file name without ``.py``);
        when omitted, the ``<name lower>_strategy`` convention is assumed.
        An ``__all__`` that becomes empty is removed.
        """
        if not os.path.exists(self.strategy_init_file):
            print(f"2. 跳过: 未找到 {self.strategy_init_file}")
            return
        if module_name is None:
            module_name = strategy_name.lower() + '_strategy'
        content = self._read_text(self.strategy_init_file)
        content = self._remove_import(content, f"from .{module_name} import {strategy_name}")
        all_match = re.search(r'__all__\s*=\s*\[(.*?)\]', content, re.DOTALL)
        if all_match:
            all_items = self._parse_names(all_match.group(1))
            if strategy_name in all_items:
                all_items.remove(strategy_name)
                if all_items:
                    replacement = f'__all__ = [{self._format_names(all_items)}]'
                else:
                    replacement = ''
                content = content[:all_match.start()] + replacement + content[all_match.end():]
        self._write_text(self.strategy_init_file, content)
        print(f"2. 已更新 {self.strategy_init_file}")

    def _remove_from_config(self, strategy_name):
        """Remove *strategy_name* from ``STRATEGY_LIST``; an emptied list is removed."""
        if not os.path.exists(self.config_file):
            print(f"3. 跳过: 未找到 {self.config_file}")
            return
        content = self._read_text(self.config_file)
        list_match = re.search(r'STRATEGY_LIST\s*=\s*\[(.*?)\]', content, re.DOTALL)
        if list_match:
            list_items = self._parse_names(list_match.group(1))
            if strategy_name in list_items:
                list_items.remove(strategy_name)
                if list_items:
                    replacement = f'STRATEGY_LIST = [{self._format_names(list_items)}]'
                else:
                    replacement = ''
                content = content[:list_match.start()] + replacement + content[list_match.end():]
        self._write_text(self.config_file, content)
        print(f"3. 已更新 {self.config_file}")

    def _remove_from_other_files(self, strategy_name, module_name=None):
        """Remove ``from strategy.<module> import <name>`` from ``main.py``.

        *module_name* defaults to the ``<name lower>_strategy`` convention.
        """
        main_file = os.path.join(self.project_root, 'main.py')
        if not os.path.exists(main_file):
            print(f"4. 跳过: 未找到 {main_file}")
            return
        if module_name is None:
            module_name = strategy_name.lower() + '_strategy'
        content = self._read_text(main_file)
        import_statement = f'from strategy.{module_name} import {strategy_name}'
        if self._has_line(content, import_statement):
            self._write_text(main_file, self._remove_import(content, import_statement))
        print(f"4. 已更新 {main_file}")

    def show_menu(self):
        """Run the interactive menu loop until the user exits (or stdin closes)."""
        while True:
            print("\n" + "=" * 50)
            print("策略编辑器")
            print("=" * 50)
            print("1. 列出所有策略")
            print("2. 读取现有策略")
            print("3. 新增策略")
            print("4. 编辑当前策略")
            print("5. 保存当前策略")
            print("6. 删除策略")
            print("7. 退出")
            print("=" * 50)
            try:
                choice = input("请选择操作: ").strip()
            except EOFError:
                # stdin closed: leave the menu instead of crashing.
                choice = '7'
            if choice == '1':
                self.list_strategies()
            elif choice == '2':
                strategy_name = input("请输入策略名称: ").strip()
                self.read_strategy(strategy_name)
            elif choice == '3':
                strategy_name = input("请输入新策略名称: ").strip()
                self.create_new_strategy(strategy_name)
            elif choice == '4':
                self.edit_strategy()
            elif choice == '5':
                self.save_strategy()
            elif choice == '6':
                strategy_name = input("请输入要删除的策略名称: ").strip()
                self.delete_strategy(strategy_name)
            elif choice == '7':
                print("\n退出策略编辑器")
                break
            else:
                print("\n错误: 无效的选择")
# Script entry point: start the interactive strategy editor menu.
if __name__ == "__main__":
    StrategyEditor().show_menu()