# -*- coding: utf-8 -*- """ 个股筛选模块 用于根据不同条件筛选股票,如排除ST、排除北交所、排除科创板等 """ import pandas as pd import logging from typing import List, Dict, Optional from datetime import datetime class StockFilter: """ 个股筛选器 根据预设条件筛选股票池 """ def __init__(self, data_fetcher=None): """ 初始化个股筛选器 Args: data_fetcher: 数据获取器实例 """ self.logger = logging.getLogger('stock_filter') self.data_fetcher = data_fetcher self.all_stocks = None # 全市场股票列表 self.logger.info('个股筛选器初始化完成') def get_stock_pool(self, filter_config: Dict) -> List[str]: """ 根据筛选配置获取股票池 Args: filter_config: 筛选配置字典 - include_st: 是否包含ST股票,默认False - include_bse: 是否包含北交所股票,默认False - include_sse_star: 是否包含科创板股票,默认False - include_gem: 是否包含创业板股票,默认True - include_main: 是否包含主板股票,默认True - custom_stocks: 自定义股票列表,优先级最高 Returns: List[str]: 筛选后的股票代码列表 """ try: # 如果有自定义股票列表,直接返回 if filter_config.get('custom_stocks'): self.logger.info(f'使用自定义股票池: {filter_config["custom_stocks"]}') return filter_config['custom_stocks'] # 获取全市场股票列表 if self.all_stocks is None: self._load_all_stocks() if self.all_stocks is None or self.all_stocks.empty: self.logger.error('无法获取股票列表,返回空股票池') return [] # 开始筛选 filtered_stocks = self.all_stocks.copy() # 1. 筛选ST股票 if not filter_config.get('include_st', False): filtered_stocks = filtered_stocks[~filtered_stocks['name'].str.contains('ST')] self.logger.info('排除ST/*ST股票') # 2. 筛选北交所股票(股票代码以8开头) if not filter_config.get('include_bse', False): filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('8')] self.logger.info('排除北交所股票') # 3. 筛选科创板股票(股票代码以688开头) if not filter_config.get('include_sse_star', False): filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('688')] self.logger.info('排除科创板股票') # 4. 筛选创业板股票(股票代码以3开头) if not filter_config.get('include_gem', True): filtered_stocks = filtered_stocks[~filtered_stocks['ts_code'].str.startswith('3')] self.logger.info('排除创业板股票') # 5. 筛选主板股票(股票代码以0或6开头) if not filter_config.get('include_main', True): filtered_stocks = filtered_stocks[ ~filtered_stocks['ts_code'].str.startswith('0') & ~filtered_stocks['ts_code'].str.startswith('6') ] self.logger.info('排除主板股票') # 转换为股票代码列表 stock_list = filtered_stocks['ts_code'].tolist() self.logger.info(f'筛选完成,共得到 {len(stock_list)} 只股票') self.logger.debug(f'筛选后的股票列表: {stock_list}') return stock_list except Exception as e: self.logger.error(f'筛选股票池失败: {e}') return [] def _load_all_stocks(self): """ 加载全市场股票列表 """ try: if self.data_fetcher is None: self.logger.error('数据获取器未初始化,无法加载股票列表') return # 使用数据获取器获取股票基本信息 self.all_stocks = self.data_fetcher.get_stock_basic() if self.all_stocks is not None and not self.all_stocks.empty: self.logger.info(f'成功加载 {len(self.all_stocks)} 只股票的基本信息') else: self.logger.error('加载股票基本信息失败') except Exception as e: self.logger.error(f'加载全市场股票列表失败: {e}') def add_custom_filter(self, condition: callable) -> None: """ 添加自定义筛选条件 Args: condition: 筛选条件函数,接收股票基本信息DataFrame,返回布尔Series """ if self.all_stocks is not None and not self.all_stocks.empty: self.all_stocks = self.all_stocks[condition(self.all_stocks)] self.logger.info('应用自定义筛选条件') def get_stock_info(self, ts_code: str) -> Optional[Dict]: """ 获取单只股票的基本信息 Args: ts_code: 股票代码 Returns: Optional[Dict]: 股票基本信息字典 """ if self.all_stocks is None: self._load_all_stocks() if self.all_stocks is not None: stock_info = self.all_stocks[self.all_stocks['ts_code'] == ts_code] if not stock_info.empty: return stock_info.iloc[0].to_dict() return None