154 lines
5.9 KiB
Python
154 lines
5.9 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
个股筛选模块
|
||
用于根据不同条件筛选股票,如排除ST、排除北交所、排除科创板等
|
||
"""
|
||
|
||
import pandas as pd
|
||
import logging
|
||
from typing import List, Dict, Optional
|
||
from datetime import datetime
|
||
|
||
class StockFilter:
|
||
"""
|
||
个股筛选器
|
||
根据预设条件筛选股票池
|
||
"""
|
||
|
||
def __init__(self, data_fetcher=None):
|
||
"""
|
||
初始化个股筛选器
|
||
|
||
Args:
|
||
data_fetcher: 数据获取器实例
|
||
"""
|
||
self.logger = logging.getLogger('stock_filter')
|
||
self.data_fetcher = data_fetcher
|
||
self.all_stocks = None # 全市场股票列表
|
||
|
||
self.logger.info('个股筛选器初始化完成')
|
||
|
||
def get_stock_pool(self, filter_config: Dict) -> List[str]:
|
||
"""
|
||
根据筛选配置获取股票池
|
||
|
||
Args:
|
||
filter_config: 筛选配置字典
|
||
- include_st: 是否包含ST股票,默认False
|
||
- include_bse: 是否包含北交所股票,默认False
|
||
- include_sse_star: 是否包含科创板股票,默认False
|
||
- include_gem: 是否包含创业板股票,默认True
|
||
- include_main: 是否包含主板股票,默认True
|
||
- custom_stocks: 自定义股票列表,优先级最高
|
||
|
||
Returns:
|
||
List[str]: 筛选后的股票代码列表
|
||
"""
|
||
try:
|
||
# 如果有自定义股票列表,直接返回
|
||
if filter_config.get('custom_stocks'):
|
||
self.logger.info(f'使用自定义股票池: {filter_config["custom_stocks"]}')
|
||
return filter_config['custom_stocks']
|
||
|
||
# 获取全市场股票列表
|
||
if self.all_stocks is None:
|
||
self._load_all_stocks()
|
||
|
||
if self.all_stocks is None or self.all_stocks.empty:
|
||
self.logger.error('无法获取股票列表,返回空股票池')
|
||
return []
|
||
|
||
# 开始筛选
|
||
filtered_stocks = self.all_stocks.copy()
|
||
|
||
# 1. 筛选ST股票
|
||
if not filter_config.get('include_st', False):
|
||
filtered_stocks = filtered_stocks[~filtered_stocks['name'].str.contains('ST')]
|
||
self.logger.info('排除ST/*ST股票')
|
||
|
||
# 2. 筛选北交所股票(股票代码以8开头)
|
||
if not filter_config.get('include_bse', False):
|
||
filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('8')]
|
||
self.logger.info('排除北交所股票')
|
||
|
||
# 3. 筛选科创板股票(股票代码以688开头)
|
||
if not filter_config.get('include_sse_star', False):
|
||
filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('688')]
|
||
self.logger.info('排除科创板股票')
|
||
|
||
# 4. 筛选创业板股票(股票代码以3开头)
|
||
if not filter_config.get('include_gem', True):
|
||
filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('3')]
|
||
self.logger.info('排除创业板股票')
|
||
|
||
# 5. 筛选主板股票(股票代码以0或6开头)
|
||
if not filter_config.get('include_main', True):
|
||
filtered_stocks = filtered_stocks[
|
||
~filtered_stocks['ts_code'].str.startswith('0') &
|
||
~filtered_stocks['ts_code'].str.startswith('6')
|
||
]
|
||
self.logger.info('排除主板股票')
|
||
|
||
# 转换为股票代码列表
|
||
stock_list = filtered_stocks['ts_code'].tolist()
|
||
|
||
self.logger.info(f'筛选完成,共得到 {len(stock_list)} 只股票')
|
||
self.logger.debug(f'筛选后的股票列表: {stock_list}')
|
||
|
||
return stock_list
|
||
|
||
except Exception as e:
|
||
self.logger.error(f'筛选股票池失败: {e}')
|
||
return []
|
||
|
||
def _load_all_stocks(self):
|
||
"""
|
||
加载全市场股票列表
|
||
"""
|
||
try:
|
||
if self.data_fetcher is None:
|
||
self.logger.error('数据获取器未初始化,无法加载股票列表')
|
||
return
|
||
|
||
# 使用数据获取器获取股票基本信息
|
||
self.all_stocks = self.data_fetcher.get_stock_basic()
|
||
|
||
if self.all_stocks is not None and not self.all_stocks.empty:
|
||
self.logger.info(f'成功加载 {len(self.all_stocks)} 只股票的基本信息')
|
||
else:
|
||
self.logger.error('加载股票基本信息失败')
|
||
|
||
except Exception as e:
|
||
self.logger.error(f'加载全市场股票列表失败: {e}')
|
||
|
||
def add_custom_filter(self, condition: callable) -> None:
|
||
"""
|
||
添加自定义筛选条件
|
||
|
||
Args:
|
||
condition: 筛选条件函数,接收股票基本信息DataFrame,返回布尔Series
|
||
"""
|
||
if self.all_stocks is not None and not self.all_stocks.empty:
|
||
self.all_stocks = self.all_stocks[condition(self.all_stocks)]
|
||
self.logger.info('应用自定义筛选条件')
|
||
|
||
def get_stock_info(self, ts_code: str) -> Optional[Dict]:
|
||
"""
|
||
获取单只股票的基本信息
|
||
|
||
Args:
|
||
ts_code: 股票代码
|
||
|
||
Returns:
|
||
Optional[Dict]: 股票基本信息字典
|
||
"""
|
||
if self.all_stocks is None:
|
||
self._load_all_stocks()
|
||
|
||
if self.all_stocks is not None:
|
||
stock_info = self.all_stocks[self.all_stocks['ts_code'] == ts_code]
|
||
if not stock_info.empty:
|
||
return stock_info.iloc[0].to_dict()
|
||
|
||
return None
|