Files
backtrader/data/stock_filter.py
2026-01-17 21:21:30 +08:00

154 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
个股筛选模块
用于根据不同条件筛选股票如排除ST、排除北交所、排除科创板等
"""
import pandas as pd
import logging
from typing import List, Dict, Optional
from datetime import datetime
class StockFilter:
"""
个股筛选器
根据预设条件筛选股票池
"""
def __init__(self, data_fetcher=None):
"""
初始化个股筛选器
Args:
data_fetcher: 数据获取器实例
"""
self.logger = logging.getLogger('stock_filter')
self.data_fetcher = data_fetcher
self.all_stocks = None # 全市场股票列表
self.logger.info('个股筛选器初始化完成')
def get_stock_pool(self, filter_config: Dict) -> List[str]:
"""
根据筛选配置获取股票池
Args:
filter_config: 筛选配置字典
- include_st: 是否包含ST股票默认False
- include_bse: 是否包含北交所股票默认False
- include_sse_star: 是否包含科创板股票默认False
- include_gem: 是否包含创业板股票默认True
- include_main: 是否包含主板股票默认True
- custom_stocks: 自定义股票列表,优先级最高
Returns:
List[str]: 筛选后的股票代码列表
"""
try:
# 如果有自定义股票列表,直接返回
if filter_config.get('custom_stocks'):
self.logger.info(f'使用自定义股票池: {filter_config["custom_stocks"]}')
return filter_config['custom_stocks']
# 获取全市场股票列表
if self.all_stocks is None:
self._load_all_stocks()
if self.all_stocks is None or self.all_stocks.empty:
self.logger.error('无法获取股票列表,返回空股票池')
return []
# 开始筛选
filtered_stocks = self.all_stocks.copy()
# 1. 筛选ST股票
if not filter_config.get('include_st', False):
filtered_stocks = filtered_stocks[~filtered_stocks['name'].str.contains('ST')]
self.logger.info('排除ST/*ST股票')
# 2. 筛选北交所股票股票代码以8开头
if not filter_config.get('include_bse', False):
filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('8')]
self.logger.info('排除北交所股票')
# 3. 筛选科创板股票股票代码以688开头
if not filter_config.get('include_sse_star', False):
filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('688')]
self.logger.info('排除科创板股票')
# 4. 筛选创业板股票股票代码以3开头
if not filter_config.get('include_gem', True):
filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('3')]
self.logger.info('排除创业板股票')
# 5. 筛选主板股票股票代码以0或6开头
if not filter_config.get('include_main', True):
filtered_stocks = filtered_stocks[
~filtered_stocks['ts_code'].str.startswith('0') &
~filtered_stocks['ts_code'].str.startswith('6')
]
self.logger.info('排除主板股票')
# 转换为股票代码列表
stock_list = filtered_stocks['ts_code'].tolist()
self.logger.info(f'筛选完成,共得到 {len(stock_list)} 只股票')
self.logger.debug(f'筛选后的股票列表: {stock_list}')
return stock_list
except Exception as e:
self.logger.error(f'筛选股票池失败: {e}')
return []
def _load_all_stocks(self):
"""
加载全市场股票列表
"""
try:
if self.data_fetcher is None:
self.logger.error('数据获取器未初始化,无法加载股票列表')
return
# 使用数据获取器获取股票基本信息
self.all_stocks = self.data_fetcher.get_stock_basic()
if self.all_stocks is not None and not self.all_stocks.empty:
self.logger.info(f'成功加载 {len(self.all_stocks)} 只股票的基本信息')
else:
self.logger.error('加载股票基本信息失败')
except Exception as e:
self.logger.error(f'加载全市场股票列表失败: {e}')
def add_custom_filter(self, condition: callable) -> None:
"""
添加自定义筛选条件
Args:
condition: 筛选条件函数接收股票基本信息DataFrame返回布尔Series
"""
if self.all_stocks is not None and not self.all_stocks.empty:
self.all_stocks = self.all_stocks[condition(self.all_stocks)]
self.logger.info('应用自定义筛选条件')
def get_stock_info(self, ts_code: str) -> Optional[Dict]:
"""
获取单只股票的基本信息
Args:
ts_code: 股票代码
Returns:
Optional[Dict]: 股票基本信息字典
"""
if self.all_stocks is None:
self._load_all_stocks()
if self.all_stocks is not None:
stock_info = self.all_stocks[self.all_stocks['ts_code'] == ts_code]
if not stock_info.empty:
return stock_info.iloc[0].to_dict()
return None