2-深度学习挖短线股-3-训练数据计算

发布于:2025-06-27 ⋅ 阅读:(18) ⋅ 点赞:(0)

2-3 合并输入特征

     首先定义了数据预处理函数,将连续 n 天的 K 线数据(如开盘价、收盘价、成交量等)合并为一行特征,同时保留对应的目标标签(buy 列,表示是否应该买入);然后读取股票代码列表,对每只股票的数据进行检查,如果尚未预处理,则读取扩展数据,筛选出 2017 年底前的数据,并调用预处理函数将多日数据合并为单行特征,最后保存预处理后的文件。这种处理方式将时间序列数据转换为适合机器学习模型输入的格式,便于后续进行训练和预测。为了应用时序上的信息,将前10日的指标数据合并到当日,作为输入特征。

       程序的核心功能是将股票的时序数据转换为适合机器学习模型输入的格式,通过滑动窗口方法构建特征矩阵。具体来说,程序将前 N 天的多项技术指标合并为一行特征向量,并将当日的交易信号作为目标值,为后续的预测模型提供数据准备。

功能总结

  1. 数据预处理:删除日期和目标值列,保留技术指标作为特征
  2. 时序特征构建:使用滑动窗口(长度为 FEATURE_N)将历史数据转换为特征向量
  3. 目标值对齐:将当日的buy信号作为对应特征向量的预测目标
  4. 批量处理:对所有符合条件的股票执行相同的预处理操作
# -*- coding: utf-8 -*-
"""
Created on Thu Jun  5 09:20:50 2025
为了应用时序上的信息,将前10日的指标数据合并到当日,作为输入特征
@author: Administrator
"""

import numpy as np  # 导入数值计算库
import pandas as pd  # 导入数据处理库
import os  # 导入操作系统接口库

# 使用前FEATURE_N的K线数据作为输入特征
FEATURE_N = 10  # 定义时间窗口大小,即使用前10天的数据构建特征

# 预处理,将n行数据作为输入特征
def data_preprocessing(df, stk_code, n):
    df = df.copy()  # 创建数据副本,避免修改原始数据
    
    # 删除无效数据列,保留特征数据
    ft_df = df.drop(columns=['date', 'buy'])  # 移除日期和目标值列,保留技术指标
    
    # 返回值
    out_df = pd.DataFrame()  # 初始化输出DataFrame
    
    # 生成新特征数据
    for i in range(n, df.shape[0]):  # 从第n行开始遍历,确保有足够的历史数据
        # 取n行数据
        part_df = ft_df.iloc[i - n : i]  # 获取当前行前n天的技术指标数据
        
        # 将n行合并为一行
        new_ft_df = pd.DataFrame(part_df.values.reshape(1, -1))  # 将n行数据展平为一行
        
        # 添加到输出DataFrame
        out_df = out_df.append(new_ft_df)
    
    # 添加目标值(当日的buy信号)
    out_df['target'] = df.iloc[n:df.shape[0]]['buy'].values  # 将当日的交易信号作为预测目标
    
    # 重置索引并保存
    out_df = out_df.reset_index(drop=True)  # 重置索引
    out_df.to_csv('./baostock/data_pre/{}.csv'.format(stk_code), index=False)  # 保存处理后的数据
    
    return out_df  # 返回处理后的DataFrame

# 主程序:批量处理所有股票
stk_code_file = './stk_data/dp_stock_list.csv'  # 定义股票代码文件路径
stk_list = pd.read_csv(stk_code_file)['code'].tolist()  # 读取股票代码列表

for stk_code in stk_list:  # 遍历所有股票
    # 判断是否已经经过预处理(文件是否存在)
    data_file = './baostock/data_pre/{}.csv'.format(stk_code)  # 定义预处理后的数据文件路径
    
    if not os.path.exists(data_file):  # 检查文件是否已存在
        print('processing {} ...'.format(stk_code))  # 打印正在处理的股票代码
        
        # 读取数据并限制时间范围
        df = pd.read_csv('./baostock/data_ext/{}.csv'.format(stk_code))  # 读取扩展后的股票数据
        df = df[df['date'] <= '2017-12-31']  # 仅保留2017年底前的数据
        
        # 执行数据预处理
        df = data_preprocessing(df, stk_code, FEATURE_N)  # 调用预处理函数


网站公告

今日签到

点亮在社区的每一天
去签到