#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
策略编辑器

用于读取、修改、新增和删除策略,并在保存或删除时自动更新所有引用该策略的地方
(strategy/__init__.py、config.py 和 main.py)。
"""

import os
|
||
import re
|
||
import sys
|
||
import glob
|
||
import json
|
||
from datetime import datetime
|
||
|
||
class StrategyEditor:
    """Interactive editor for the strategy classes under ``strategy/``.

    It can list, read, create, edit, save and delete strategies. Saving or
    deleting a strategy also keeps its references in sync:

    * ``strategy/__init__.py`` -- the ``from .<module> import <Class>`` line
      and the ``__all__`` list;
    * ``config.py`` -- the ``STRATEGY_LIST`` list;
    * ``main.py`` -- the ``from strategy.<module> import <Class>`` line.
    """

    # Modules in the strategy directory that are infrastructure, not strategies.
    EXCLUDED_FILES = ('__init__.py', 'base_strategy.py', 'market_filter.py')

    def __init__(self):
        self.project_root = os.path.dirname(os.path.abspath(__file__))
        self.strategy_dir = os.path.join(self.project_root, 'strategy')
        self.config_file = os.path.join(self.project_root, 'config.py')
        self.strategy_init_file = os.path.join(self.strategy_dir, '__init__.py')

        # Strategy currently loaded in the editor (all None until read/create).
        self.current_strategy = None          # file name, e.g. "foo_strategy.py"
        self.current_strategy_name = None     # class name, e.g. "Foo"
        self.current_strategy_content = None  # full source text of that file

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_text(path):
        """Return the UTF-8 text content of *path*."""
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()

    @staticmethod
    def _write_text(path, content):
        """Overwrite *path* with *content*, encoded as UTF-8."""
        with open(path, 'w', encoding='utf-8') as f:
            f.write(content)

    def _strategy_files(self, include_excluded=False):
        """Return the sorted ``*.py`` paths in the strategy directory.

        Args:
            include_excluded: when False (default), skip ``EXCLUDED_FILES``.
        """
        # glob.escape: the project path itself may contain glob metacharacters.
        pattern = os.path.join(glob.escape(self.strategy_dir), '*.py')
        paths = sorted(glob.glob(pattern))
        if include_excluded:
            return paths
        return [p for p in paths if os.path.basename(p) not in self.EXCLUDED_FILES]

    @staticmethod
    def _defines_class(content, class_name):
        """Return True if *content* has a ``class <class_name>`` statement.

        The name must end at a word boundary, so ``Foo`` does not match
        ``class FooBar`` (a plain substring test would).
        """
        if not class_name:
            return False
        pattern = r'^[ \t]*class[ \t]+' + re.escape(class_name) + r'\b'
        return re.search(pattern, content, re.MULTILINE) is not None

    def _find_strategy_file(self, class_name, include_excluded=False):
        """Return ``(path, content)`` of the first file defining *class_name*.

        Returns ``(None, None)`` when no file defines it.
        """
        for path in self._strategy_files(include_excluded):
            content = self._read_text(path)
            if self._defines_class(content, class_name):
                return path, content
        return None, None

    @staticmethod
    def _import_insert_index(lines, prefixes):
        """Return the index just past the last import line starting with *prefixes*.

        A parenthesised multi-line import is skipped as a whole, so the result
        never points inside its parentheses. Returns None if nothing matches.
        """
        insert_at = None
        i = 0
        while i < len(lines):
            line = lines[i]
            if line.startswith(prefixes) and 'import' in line:
                if '(' in line and ')' not in line:
                    while i + 1 < len(lines) and ')' not in lines[i]:
                        i += 1
                insert_at = i + 1
            i += 1
        return insert_at

    @classmethod
    def _insert_import(cls, content, import_statement, prefix):
        """Add *import_statement* to *content* unless it is already a line there.

        The new line goes right after the last import starting with *prefix*.
        Failing that, it goes after the last top-level ``import``/``from``
        line, or at the very top. The rest of the file keeps its order.
        """
        lines = content.split('\n')
        if any(line.strip() == import_statement for line in lines):
            return content
        index = cls._import_insert_index(lines, (prefix,))
        if index is None:
            index = cls._import_insert_index(lines, ('import ', 'from '))
        if index is None:
            index = 0
        lines.insert(index, import_statement)
        return '\n'.join(lines)

    @staticmethod
    def _remove_import(content, import_statement):
        """Drop every line of *content* that is exactly *import_statement*."""
        lines = content.split('\n')
        return '\n'.join(line for line in lines if line.strip() != import_statement)

    @staticmethod
    def _find_list_assignment(content, var_name):
        """Find ``<var_name> = [...]``; group(1) is the text between the brackets.

        The look-behind stops ``STRATEGY_LIST`` from matching ``MY_STRATEGY_LIST``
        or ``config.STRATEGY_LIST``.
        """
        pattern = r'(?<![\w.])' + re.escape(var_name) + r'\s*=\s*\[(.*?)\]'
        return re.search(pattern, content, re.DOTALL)

    @staticmethod
    def _parse_list_items(list_body):
        """Split a flat list literal body into bare names.

        Comments and empty entries (trailing comma, empty list) are dropped.
        """
        without_comments = re.sub(r'#[^\n]*', '', list_body)
        items = (item.strip().strip('"\'') for item in without_comments.split(','))
        return [item for item in items if item]

    @staticmethod
    def _render_list(content, match, var_name, items):
        """Replace the span of *match* in *content* with ``var_name = ["a", ...]``."""
        rendered = ', '.join(f'"{item}"' for item in items)
        return content[:match.start()] + f'{var_name} = [{rendered}]' + content[match.end():]

    @staticmethod
    def _build_template(strategy_name):
        """Return the source skeleton for a new ``BaseStrategy`` subclass."""
        lines = [
            "",
            "import pandas as pd",
            "import numpy as np",
            "import logging",
            "from typing import Dict, List, Tuple, Optional",
            "from .base_strategy import BaseStrategy",
            "",
            f"class {strategy_name}(BaseStrategy):",
            "    pass",
            "",
            "    def __init__(self, config: Dict):",
            "        super().__init__()",
            "        # 使用主日志记录器,确保日志能被正确记录到文件",
            "        self.logger = logging.getLogger('backtrader')",
            "",
            "        # 策略参数",
            "        # 在此处添加策略参数",
            "",
            "    def initialize(self, data):",
            "        # 初始化信号数据",
            "        self.signals = None",
            "",
            "    def generate_signals(self, data):",
            "        if data.empty:",
            "            return pd.DataFrame()",
            "",
            "        # 确保数据按日期从旧到新排序",
            "        data = data.sort_index()",
            "",
            "        signals = pd.DataFrame(index=data.index)",
            "        signals['signal'] = 0  # 0:无信号, 1:买入信号, -1:卖出信号",
            "        signals['strength'] = 0.0  # 信号强度",
            "        signals['reason'] = ''  # 信号原因",
            "",
            "        # 在此处添加信号生成逻辑",
            "",
            "        return signals",
            "",
            "    def execute_trades(self, data, signals):",
            "        if signals.empty:",
            "            return pd.DataFrame()",
            "",
            "        # 确保数据按日期从旧到新排序",
            "        data = data.sort_index()",
            "",
            "        trades = []",
            "        position = 0",
            "",
            "        for date, signal in signals.iterrows():",
            "            if signal['signal'] == 1 and position == 0:  # 买入信号且无持仓",
            "                # 在此处添加买入逻辑",
            "                pass",
            "            elif signal['signal'] == -1 and position > 0:  # 卖出信号且有持仓",
            "                # 在此处添加卖出逻辑",
            "                pass",
            "",
            "        return pd.DataFrame(trades)",
            "",
            "    def calculate_returns(self, data, trades):",
            "        if trades.empty:",
            "            return 0.0",
            "",
            "        # 在此处添加收益率计算逻辑",
            "        return 0.0",
        ]
        return '\n'.join(lines) + '\n'

    # ------------------------------------------------------------------
    # Public operations
    # ------------------------------------------------------------------

    def list_strategies(self):
        """Print each strategy module with the first class it defines."""
        print("\n可用策略列表:")
        print("=" * 50)

        for file_path in self._strategy_files():
            file_name = os.path.basename(file_path)
            class_match = re.search(r'class\s+(\w+)\s*\(', self._read_text(file_path))
            if class_match:
                print(f"- {class_match.group(1)} ({file_name})")
            else:
                print(f"- {os.path.splitext(file_name)[0]} ({file_name}) [无类定义]")

    def read_strategy(self, strategy_name):
        """Load the file that defines class *strategy_name* into the editor.

        Returns:
            True if found and loaded, False otherwise.
        """
        file_path, content = self._find_strategy_file(strategy_name)
        if file_path is None:
            print(f"\n错误: 未找到策略 {strategy_name}")
            return False

        self.current_strategy = os.path.basename(file_path)
        self.current_strategy_name = strategy_name
        self.current_strategy_content = content

        print(f"\n已读取策略: {strategy_name} ({self.current_strategy})")
        return True

    def create_new_strategy(self, strategy_name):
        """Load a template for a new strategy class into the editor.

        Nothing is written to disk until ``save_strategy`` is called.

        Returns:
            True on success. False if the name is invalid, a class with that
            name already exists, or the target file name is already taken.
        """
        # fullmatch: '^...$' with re.match would also accept a trailing newline.
        if not strategy_name or not re.fullmatch(r'[A-Z][A-Za-z0-9_]*', strategy_name):
            print("\n错误: 策略名必须以大写字母开头,只能包含字母、数字和下划线")
            return False

        # Infrastructure files are checked too, so e.g. "BaseStrategy" is rejected.
        existing_path, _ = self._find_strategy_file(strategy_name, include_excluded=True)
        if existing_path is not None:
            print(f"\n错误: 策略 {strategy_name} 已存在")
            return False

        file_name = strategy_name.lower() + '_strategy.py'
        # Saving would otherwise silently overwrite an unrelated module.
        if os.path.exists(os.path.join(self.strategy_dir, file_name)):
            print(f"\n错误: 文件 {file_name} 已存在")
            return False

        self.current_strategy = file_name
        self.current_strategy_name = strategy_name
        self.current_strategy_content = self._build_template(strategy_name)

        print(f"\n已创建新策略: {strategy_name} ({file_name})")
        print("请编辑策略内容,然后保存")
        return True

    def edit_strategy(self):
        """Replace the loaded content with lines typed on stdin.

        Input ends at a line reading exactly ``EOF`` or at end of input.

        Returns:
            True if new content was entered, False otherwise.
        """
        if not self.current_strategy_content:
            print("\n错误: 没有加载任何策略")
            return False

        print(f"\n编辑策略: {self.current_strategy_name}")
        print("=" * 50)
        print("当前策略内容:")
        print(self.current_strategy_content)
        print("=" * 50)
        print("\n输入新的策略内容 (以 'EOF' 结束输入):")

        new_lines = []
        while True:
            try:
                line = input()
            except EOFError:
                # Closed stdin behaves like the 'EOF' sentinel instead of crashing.
                break
            if line == 'EOF':
                break
            new_lines.append(line)

        if not new_lines:
            return False

        self.current_strategy_content = '\n'.join(new_lines) + '\n'
        print("\n策略内容已更新")
        return True

    def save_strategy(self):
        """Write the loaded strategy to disk and register it everywhere.

        Returns:
            True on success, False if nothing is loaded or the content no
            longer defines the strategy class.
        """
        if not self.current_strategy_content or not self.current_strategy_name:
            print("\n错误: 没有加载任何策略")
            return False

        # Registering an import for a class the file no longer defines would
        # make the whole `strategy` package fail to import.
        if not self._defines_class(self.current_strategy_content, self.current_strategy_name):
            print(f"\n错误: 策略内容中未找到类 {self.current_strategy_name} 的定义,未保存")
            return False

        # 1. Strategy file
        os.makedirs(self.strategy_dir, exist_ok=True)
        file_path = os.path.join(self.strategy_dir, self.current_strategy)
        self._write_text(file_path, self.current_strategy_content)
        print(f"\n1. 已保存策略文件: {file_path}")

        # 2. strategy/__init__.py
        self._update_strategy_init()
        # 3. config.py
        self._update_config()
        # 4. main.py
        self._update_other_references()

        print(f"\n策略 {self.current_strategy_name} 已保存并更新所有引用")
        return True

    def _update_strategy_init(self):
        """Add the strategy's import and ``__all__`` entry to ``strategy/__init__.py``.

        Creates the file if it is missing.
        """
        module_name = os.path.splitext(self.current_strategy)[0]
        import_statement = f"from .{module_name} import {self.current_strategy_name}"

        content = ''
        if os.path.exists(self.strategy_init_file):
            content = self._read_text(self.strategy_init_file)

        content = self._insert_import(content, import_statement, 'from .')

        all_match = self._find_list_assignment(content, '__all__')
        if all_match:
            all_items = self._parse_list_items(all_match.group(1))
            if self.current_strategy_name not in all_items:
                all_items.append(self.current_strategy_name)
                content = self._render_list(content, all_match, '__all__', all_items)
        else:
            content += f'\n\n__all__ = ["{self.current_strategy_name}"]'

        self._write_text(self.strategy_init_file, content)
        print(f"2. 已更新 {self.strategy_init_file}")

    def _update_config(self):
        """Add the strategy to ``STRATEGY_LIST`` in ``config.py``.

        Appends the list if it is absent. Does nothing if the file is missing.
        """
        if not os.path.exists(self.config_file):
            print(f"3. 未找到 {self.config_file},已跳过")
            return

        content = self._read_text(self.config_file)

        list_match = self._find_list_assignment(content, 'STRATEGY_LIST')
        if list_match:
            list_items = self._parse_list_items(list_match.group(1))
            if self.current_strategy_name not in list_items:
                list_items.append(self.current_strategy_name)
                content = self._render_list(content, list_match, 'STRATEGY_LIST', list_items)
        else:
            content += f'\n\n# ========== 策略配置 ==========\nSTRATEGY_LIST = ["{self.current_strategy_name}"]'

        self._write_text(self.config_file, content)
        print(f"3. 已更新 {self.config_file}")

    def _update_other_references(self):
        """Add the strategy's ``from strategy.<module> import <Class>`` line to ``main.py``.

        Does nothing if ``main.py`` is missing.
        """
        main_file = os.path.join(self.project_root, 'main.py')
        if not os.path.exists(main_file):
            print(f"4. 未找到 {main_file},已跳过")
            return

        module_name = os.path.splitext(self.current_strategy)[0]
        import_statement = f'from strategy.{module_name} import {self.current_strategy_name}'

        content = self._insert_import(self._read_text(main_file), import_statement, 'from strategy.')

        self._write_text(main_file, content)
        print(f"4. 已更新 {main_file}")

    def delete_strategy(self, strategy_name):
        """Delete the file defining *strategy_name* and remove its references.

        Note: the whole module file is removed, including any other classes in it.

        Returns:
            True if a file was deleted, False if the strategy was not found.
        """
        file_path, _ = self._find_strategy_file(strategy_name)
        if file_path is None:
            print(f"\n错误: 未找到策略 {strategy_name}")
            return False

        file_name = os.path.basename(file_path)
        # Use the real module name: the file need not follow '<name>_strategy.py'.
        module_name = os.path.splitext(file_name)[0]

        os.remove(file_path)
        print(f"\n1. 已删除策略文件: {file_name}")

        self._remove_from_strategy_init(strategy_name, module_name)
        self._remove_from_config(strategy_name)
        self._remove_from_other_files(strategy_name, module_name)

        # Do not leave the editor pointing at a strategy that no longer exists.
        if self.current_strategy_name == strategy_name:
            self.current_strategy = None
            self.current_strategy_name = None
            self.current_strategy_content = None

        print(f"策略 {strategy_name} 已删除并更新所有引用")
        return True

    def _remove_from_strategy_init(self, strategy_name, module_name=None):
        """Remove the strategy's import and ``__all__`` entry from ``strategy/__init__.py``.

        Args:
            strategy_name: class name to remove.
            module_name: module the class lived in. Defaults to
                ``<name lower>_strategy``.
        """
        if module_name is None:
            module_name = strategy_name.lower() + '_strategy'
        if not os.path.exists(self.strategy_init_file):
            print(f"2. 未找到 {self.strategy_init_file},已跳过")
            return

        content = self._read_text(self.strategy_init_file)
        content = self._remove_import(content, f"from .{module_name} import {strategy_name}")

        all_match = self._find_list_assignment(content, '__all__')
        if all_match:
            all_items = self._parse_list_items(all_match.group(1))
            if strategy_name in all_items:
                # An empty `__all__ = []` is kept rather than deleted: deleting it
                # would make `from strategy import *` export every public name.
                remaining = [item for item in all_items if item != strategy_name]
                content = self._render_list(content, all_match, '__all__', remaining)

        self._write_text(self.strategy_init_file, content)
        print(f"2. 已更新 {self.strategy_init_file}")

    def _remove_from_config(self, strategy_name):
        """Remove the strategy from ``STRATEGY_LIST`` in ``config.py``."""
        if not os.path.exists(self.config_file):
            print(f"3. 未找到 {self.config_file},已跳过")
            return

        content = self._read_text(self.config_file)

        list_match = self._find_list_assignment(content, 'STRATEGY_LIST')
        if list_match:
            list_items = self._parse_list_items(list_match.group(1))
            if strategy_name in list_items:
                # Keep an empty list: removing the name entirely would break
                # any code that reads config.STRATEGY_LIST.
                remaining = [item for item in list_items if item != strategy_name]
                content = self._render_list(content, list_match, 'STRATEGY_LIST', remaining)

        self._write_text(self.config_file, content)
        print(f"3. 已更新 {self.config_file}")

    def _remove_from_other_files(self, strategy_name, module_name=None):
        """Remove the strategy's ``from strategy.<module> import`` line from ``main.py``.

        Args:
            strategy_name: class name to remove.
            module_name: module the class lived in. Defaults to
                ``<name lower>_strategy``.
        """
        if module_name is None:
            module_name = strategy_name.lower() + '_strategy'
        main_file = os.path.join(self.project_root, 'main.py')
        if not os.path.exists(main_file):
            print(f"4. 未找到 {main_file},已跳过")
            return

        content = self._remove_import(
            self._read_text(main_file), f'from strategy.{module_name} import {strategy_name}'
        )

        self._write_text(main_file, content)
        print(f"4. 已更新 {main_file}")

    def show_menu(self):
        """Run the interactive menu until the user exits or input ends."""
        while True:
            print("\n" + "=" * 50)
            print("策略编辑器")
            print("=" * 50)
            print("1. 列出所有策略")
            print("2. 读取现有策略")
            print("3. 新增策略")
            print("4. 编辑当前策略")
            print("5. 保存当前策略")
            print("6. 删除策略")
            print("7. 退出")
            print("=" * 50)

            try:
                choice = input("请选择操作: ").strip()

                if choice == '1':
                    self.list_strategies()
                elif choice == '2':
                    self.read_strategy(input("请输入策略名称: ").strip())
                elif choice == '3':
                    self.create_new_strategy(input("请输入新策略名称: ").strip())
                elif choice == '4':
                    self.edit_strategy()
                elif choice == '5':
                    self.save_strategy()
                elif choice == '6':
                    self.delete_strategy(input("请输入要删除的策略名称: ").strip())
                elif choice == '7':
                    print("\n退出策略编辑器")
                    break
                else:
                    print("\n错误: 无效的选择")
            except (EOFError, KeyboardInterrupt):
                # End of input or Ctrl-C: leave cleanly instead of a traceback.
                print("\n退出策略编辑器")
                break


if __name__ == "__main__":
    # Launch the interactive strategy editor menu.
    StrategyEditor().show_menu()