Transformer - Feature Preprocessing

Published: 2024-04-24


flyfish
Raw data (training split, before scaling)

train_data.values 
[[ 5.827  2.009  1.599  0.462  4.203  1.34  30.531]
 [ 5.76   2.076  1.492  0.426  4.264  1.401 30.46 ]
 [ 5.76   1.942  1.492  0.391  4.234  1.31  30.038]
 [ 5.76   1.942  1.492  0.426  4.234  1.31  27.013]
 [ 5.693  2.076  1.492  0.426  4.142  1.371 27.787]
 [ 5.492  1.942  1.457  0.391  4.112  1.279 27.717]
 [ 5.358  1.875  1.35   0.355  3.929  1.34  27.646]
 [ 5.157  1.808  1.35   0.32   3.807  1.279 27.084]
 [ 5.157  1.741  1.279  0.355  3.777  1.218 27.787]
 [ 5.157  1.808  1.35   0.426  3.777  1.188 27.506]
 [ 5.157  1.808  1.315  0.391  3.777  1.249 27.857]
 [ 5.157  1.942  1.35   0.426  3.807  1.279 27.013]
 [ 5.09   1.942  1.279  0.391  3.807  1.279 25.044]
 [ 5.224  2.009  1.457  0.533  3.807  1.249 24.551]
 [ 5.291  1.808  1.457  0.426  3.777  1.218 23.566]
 [ 5.358  1.942  1.492  0.462  3.807  1.31  21.526]
 [ 5.358  1.942  1.492  0.462  3.868  1.279 21.948]
 [ 5.492  2.009  1.492  0.462  3.929  1.34  21.456]
 [ 5.492  1.942  1.492  0.426  3.929  1.34  22.792]
 [ 5.492  2.076  1.492  0.497  3.99   1.31  21.034]
 [ 5.626  2.143  1.528  0.533  4.051  1.371 21.174]
 [ 5.961  2.344  1.67   0.604  4.234  1.492 20.823]
 [ 6.162  2.411  1.777  0.604  4.325  1.523 21.174]
 [ 6.631  2.478  1.99   0.746  4.66   1.675 21.174]
 [ 7.167  2.947  2.132  0.782  5.026  1.858 22.792]
 [ 7.502  3.215  2.239  0.888  5.33   1.98  23.848]
 [ 7.703  3.349  2.487  1.031  5.269  1.919 24.34 ]
 ......

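Standardize the data with scikit-learn's fit and transform: the StandardScaler is fitted on the training split only (rows border1s[0]:border2s[0]) and then applied to the whole dataset.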

train_data = df_data[border1s[0]:border2s[0]]
self.scaler.fit(train_data.values)
data = self.scaler.transform(df_data.values)

The standardized data is what the model is trained on:

transform_data [[ 0.6156 -1.3896 -0.991  ...  1.1402 -0.9535  2.907 ]
 [ 0.5294 -1.2429 -1.2435 ...  1.2172 -0.6099  2.8853]
 [ 0.5294 -1.5362 -1.2435 ...  1.1794 -1.1224  2.7561]
 ...
 [ 5.6959  5.6479 11.0826 ... -0.82    0.933   1.5291]
 [ 7.1602  6.0879 13.2628 ... -0.897   1.1076  1.5077]
 [ 6.8156  5.3546 14.3529 ... -0.82    1.2765  1.5508]]
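What fit and transform compute can be sketched with NumPy alone. This is a minimal sketch, assuming StandardScaler's default settings (with_mean=True, with_std=True): fit stores the per-column mean and population standard deviation of the training rows, and transform applies (x - mean) / std to every row. The helper names and the six (first column, last column) pairs taken from the raw rows above are illustrative; `n_train` plays the role of border2s[0].

```python
import numpy as np

def fit_standard_scaler(train: np.ndarray):
    """Return per-column mean and population std (ddof=0) of the training rows."""
    mean = train.mean(axis=0)
    scale = train.std(axis=0)       # biased estimator (ddof=0)
    scale[scale == 0.0] = 1.0       # leave constant columns unscaled
    return mean, scale

def standardize(data: np.ndarray, mean: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """z-score every row with the statistics learned from the training rows."""
    return (data - mean) / scale

# Illustrative stand-in for df_data.values: (first column, last column) pairs.
values = np.array([[5.827, 30.531],
                   [5.760, 30.460],
                   [5.760, 30.038],
                   [5.693, 27.787],
                   [7.167, 22.792],
                   [7.703, 24.340]])
n_train = 4  # hypothetical training length, plays the role of border2s[0]

mean, scale = fit_standard_scaler(values[:n_train])  # fit on training rows only
data = standardize(values, mean, scale)              # transform all rows

print(data[:n_train].mean(axis=0).round(6))  # close to 0 for the training rows
print(data[n_train:, 0])                     # later rows can land far from 0
```

Because the statistics come only from the training rows, rows whose level drifts away from the training period map to large z-scores (about 30 for the first column here). This is consistent with the last rows of transform_data above, where the third column exceeds 10.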

inverse_transform restores the data to its original scale:

def inverse_transform(self, data):
    return self.scaler.inverse_transform(data)
inverse_transform_data: [[ 5.827  2.009  1.599 ...  4.203  1.34  30.531]
 [ 5.76   2.076  1.492 ...  4.264  1.401 30.46 ]
 [ 5.76   1.942  1.492 ...  4.234  1.31  30.038]
 ...
 [ 9.779  5.224  6.716 ...  2.65   1.675 26.028]
 [10.918  5.425  7.64  ...  2.589  1.706 25.958]
 [10.65   5.09   8.102 ...  2.65   1.736 26.099]]
 ......
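inverse_transform undoes the standardization column by column, x = z * std + mean, using the same training statistics. A minimal NumPy sketch of the round trip, assuming StandardScaler's default settings; the four rows are the first three columns of the raw data above.

```python
import numpy as np

# First three columns of the first raw rows shown above.
raw = np.array([[5.827, 2.009, 1.599],
                [5.760, 2.076, 1.492],
                [5.760, 1.942, 1.492],
                [5.693, 2.076, 1.492]])

mean = raw.mean(axis=0)       # statistics learned by fit
scale = raw.std(axis=0)       # population std, as StandardScaler uses

z = (raw - mean) / scale      # what transform returns
restored = z * scale + mean   # what inverse_transform returns

print(np.abs(restored - raw).max())  # only floating-point error remains
```

This is the step that maps values living in the standardized space, such as model outputs, back to the original units.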

Configuration (dataset attributes printed at the end of __read_data__)

seq_len:24
label_len:12
pred_len:24
set_type:0
features:M
target:OT
scale:True
timeenc:1
freq:h
root_path:./dataset/ETT-small/
data_path:ETTm1.csv
scaler:StandardScaler()

data_x holds the training data: with set_type 0 it is rows border1:border2 of the standardized array.

data_x:[[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00  1.1402e+00 -9.5346e-01
   2.9070e+00]
 [ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.2172e+00 -6.0995e-01
   2.8853e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00  1.1794e+00 -1.1224e+00
   2.7561e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  1.1794e+00 -1.1224e+00
   1.8305e+00]
 [ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.0633e+00 -7.7889e-01
   2.0673e+00]
 [ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00  1.0254e+00 -1.2970e+00
   2.0459e+00]
 [ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00  7.9439e-01 -9.5346e-01
   2.0242e+00]
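Which rows end up in data_x is decided by border1s/border2s in __read_data__. The sketch below mirrors that arithmetic with the author's 50% train / 20% test split; `split_borders` is a hypothetical helper and n_rows = 1000 is an assumed length, not the size of ETTm1.csv.

```python
def split_borders(n_rows: int, seq_len: int, train_ratio: float = 0.5, test_ratio: float = 0.2):
    """Mirror the border1s/border2s arithmetic in Dataset_Custom.__read_data__."""
    num_train = int(n_rows * train_ratio)
    num_test = int(n_rows * test_ratio)
    num_vali = n_rows - num_train - num_test
    # Val/test start seq_len rows early so their first window has a full input history.
    border1s = [0, num_train - seq_len, n_rows - num_test - seq_len]
    border2s = [num_train, num_train + num_vali, n_rows]
    return border1s, border2s

border1s, border2s = split_borders(n_rows=1000, seq_len=24)
for name, b1, b2 in zip(["train", "val", "test"], border1s, border2s):
    print(f"{name}: rows {b1}..{b2 - 1} ({b2 - b1} rows)")
# train: rows 0..499 (500 rows)
# val: rows 476..799 (324 rows)
# test: rows 776..999 (224 rows)
```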

Timestamp encoding (timeenc=1, freq=h)

Original values

df_stamp['date'].values: [
 '2016-07-01T00:00:00.000000000' '2016-07-01T00:15:00.000000000'
 '2016-07-01T00:30:00.000000000' '2016-07-01T00:45:00.000000000'
 '2016-07-01T01:00:00.000000000' '2016-07-01T01:15:00.000000000'
 '2016-07-01T01:30:00.000000000' '2016-07-01T01:45:00.000000000'
 '2016-07-01T02:00:00.000000000' '2016-07-01T02:15:00.000000000'
 '2016-07-01T02:30:00.000000000' '2016-07-01T02:45:00.000000000'
 '2016-07-01T03:00:00.000000000' '2016-07-01T03:15:00.000000000'
 '2016-07-01T03:30:00.000000000' '2016-07-01T03:45:00.000000000'
 '2016-07-01T04:00:00.000000000' '2016-07-01T04:15:00.000000000'
 '2016-07-01T04:30:00.000000000' '2016-07-01T04:45:00.000000000'
 '2016-07-01T05:00:00.000000000' '2016-07-01T05:15:00.000000000'
 ......

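After encoding (columns: hour of day, day of week, day of month, day of year, each scaled to [-0.5, 0.5])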

data_stamp: [[-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 ......

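The encoded values can be reproduced from the calendar fields. The sketch below re-implements, as an assumption, the four hourly features computed by the time_features helper in Informer / Time-Series-Library style code; `hourly_time_features` is a hypothetical name. For 2016-07-01 00:00 (a Friday, day 183 of a leap year) it gives the first row above, and because ETTm1 is sampled every 15 minutes while freq is 'h', the four quarter-hour rows within an hour share one encoding.

```python
import numpy as np
import pandas as pd

def hourly_time_features(dates: pd.DatetimeIndex) -> np.ndarray:
    """Assumed formulas: four calendar features in [-0.5, 0.5], one row per timestamp."""
    hour_of_day = np.asarray(dates.hour, dtype=float) / 23.0 - 0.5
    day_of_week = np.asarray(dates.dayofweek, dtype=float) / 6.0 - 0.5      # Monday = 0
    day_of_month = (np.asarray(dates.day, dtype=float) - 1) / 30.0 - 0.5
    day_of_year = (np.asarray(dates.dayofyear, dtype=float) - 1) / 365.0 - 0.5
    return np.stack([hour_of_day, day_of_week, day_of_month, day_of_year], axis=1)

dates = pd.date_range("2016-07-01 00:00", periods=8, freq="15min")
feats = hourly_time_features(dates)
print(feats.round(4))
```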
Window indices for the first two samples (printed by __getitem__):

s_begin: 0
s_end: 24
r_begin: 12
r_end: 48

s_begin: 1
s_end: 25
r_begin: 13
r_end: 49

......
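These indices follow from __getitem__ with seq_len=24, label_len=12 and pred_len=24: the encoder window is [index, index + 24) and the decoder window starts label_len steps before the encoder window ends. A small sketch reproducing them, plus the sample count from __len__; the helper names are hypothetical and n_rows = 1000 is an assumed split length.

```python
def window_indices(index: int, seq_len: int = 24, label_len: int = 12, pred_len: int = 24):
    """Same arithmetic as Dataset_Custom.__getitem__."""
    s_begin = index
    s_end = s_begin + seq_len
    r_begin = s_end - label_len
    r_end = r_begin + label_len + pred_len
    return s_begin, s_end, r_begin, r_end

def num_samples(n_rows: int, seq_len: int = 24, pred_len: int = 24) -> int:
    """Same formula as Dataset_Custom.__len__."""
    return n_rows - seq_len - pred_len + 1

print(window_indices(0))    # (0, 24, 12, 48)
print(window_indices(1))    # (1, 25, 13, 49)
print(num_samples(1000))    # 953
print(window_indices(952))  # (952, 976, 964, 1000): the last window ends at the final row
```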
Arrays returned for the first sample (index 0):

seq_x: [[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00  1.1402e+00 -9.5346e-01
   2.9070e+00]
 [ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.2172e+00 -6.0995e-01
   2.8853e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00  1.1794e+00 -1.1224e+00
   2.7561e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  1.1794e+00 -1.1224e+00
   1.8305e+00]
 [ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00  1.0633e+00 -7.7889e-01
   2.0673e+00]
 [ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00  1.0254e+00 -1.2970e+00
   2.0459e+00]
 [ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00  7.9439e-01 -9.5346e-01
   2.0242e+00]
 [-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.7271e+00  6.4040e-01 -1.2970e+00
   1.8522e+00]
 [-2.4573e-01 -1.9762e+00 -1.7460e+00 -1.6280e+00  6.0253e-01 -1.6405e+00
   2.0673e+00]
 [-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.4268e+00  6.0253e-01 -1.8094e+00
   1.9814e+00]
 [-2.4573e-01 -1.8295e+00 -1.6611e+00 -1.5260e+00  6.0253e-01 -1.4659e+00
   2.0888e+00]
 [-2.4573e-01 -1.5362e+00 -1.5785e+00 -1.4268e+00  6.4040e-01 -1.2970e+00
   1.8305e+00]
 [-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00  6.4040e-01 -1.2970e+00
   1.2280e+00]
 [-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00  6.4040e-01 -1.4659e+00
   1.0771e+00]
 [-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00  6.0253e-01 -1.6405e+00
   7.7573e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  6.4040e-01 -1.1224e+00
   1.5150e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  7.1740e-01 -1.2970e+00
   2.8063e-01]
 [ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00  7.9439e-01 -9.5346e-01
   1.3008e-01]
 [ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  7.9439e-01 -9.5346e-01
   5.3889e-01]
 [ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00  8.7139e-01 -1.1224e+00
   9.5210e-04]
 [ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00  9.4839e-01 -7.7889e-01
   4.3791e-02]
 [ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01  1.1794e+00 -9.7496e-02
  -6.3613e-02]
 [ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01  1.2942e+00  7.7075e-02
   4.3791e-02]
 [ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01  1.7171e+00  9.3304e-01
   4.3791e-02]]
seq_y: [[-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00  6.4040e-01 -1.2970e+00
  1.2280e+00]
[-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00  6.4040e-01 -1.4659e+00
  1.0771e+00]
[-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00  6.0253e-01 -1.6405e+00
  7.7573e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  6.4040e-01 -1.1224e+00
  1.5150e-01]
[ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00  7.1740e-01 -1.2970e+00
  2.8063e-01]
[ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00  7.9439e-01 -9.5346e-01
  1.3008e-01]
[ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00  7.9439e-01 -9.5346e-01
  5.3889e-01]
[ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00  8.7139e-01 -1.1224e+00
  9.5210e-04]
[ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00  9.4839e-01 -7.7889e-01
  4.3791e-02]
[ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01  1.1794e+00 -9.7496e-02
 -6.3613e-02]
[ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01  1.2942e+00  7.7075e-02
  4.3791e-02]
[ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01  1.7171e+00  9.3304e-01
  4.3791e-02]
[ 2.3382e+00  6.6367e-01  2.6662e-01 -4.1824e-01  2.1791e+00  1.9636e+00
  5.3889e-01]
[ 2.7688e+00  1.2503e+00  5.1909e-01 -1.1793e-01  2.5628e+00  2.6506e+00
  8.6202e-01]
[ 3.0272e+00  1.5436e+00  1.1043e+00  2.8720e-01  2.4858e+00  2.3071e+00
  1.0126e+00]
[ 2.6827e+00  1.1037e+00  6.8662e-01  8.3219e-02  2.2169e+00  2.1325e+00
  6.4660e-01]
[ 2.6827e+00  1.3970e+00  6.8662e-01  2.8720e-01  2.2561e+00  4.0246e+00
  6.4660e-01]
[ 2.9411e+00  1.3970e+00  6.0168e-01  3.8636e-01  2.5628e+00  3.1630e+00
  8.4030e-01]
[ 2.9411e+00  1.6903e+00  6.8662e-01  4.8835e-01  2.4858e+00  2.3071e+00
  8.6202e-01]
[ 2.8549e+00  1.3970e+00  6.8662e-01  3.8636e-01  2.4479e+00  2.3071e+00
  9.0486e-01]
[ 2.7105e-01  8.1033e-01  1.0217e+00  6.8951e-01 -4.3503e-01 -4.3537e-01
  1.9465e-01]
[-1.5960e-01  5.1701e-01  8.5414e-01  6.8951e-01 -6.2815e-01 -6.0995e-01
  3.2378e-01]
[-2.4573e-01  3.7035e-01  6.8662e-01  6.8951e-01 -5.8902e-01 -4.3537e-01
  4.0976e-01]
[-3.3186e-01 -5.0961e-01 -4.8842e-01 -1.1793e-01 -7.0514e-01 -2.6644e-01
 -1.4198e+00]
[-1.0196e+00 -2.1629e-01 -2.3595e-01 -3.1908e-01 -7.8214e-01 -7.7889e-01
 -1.0970e+00]
[-1.0196e+00 -6.5627e-01 -5.7100e-01 -4.1824e-01 -7.4301e-01 -6.0995e-01
 -1.0110e+00]
[-8.4735e-01 -2.1629e-01 -3.2089e-01 -2.1709e-01 -6.2815e-01 -2.6644e-01
 -7.7413e-01]
[-8.4735e-01 -5.0961e-01 -2.3595e-01 -1.1793e-01 -6.2815e-01 -7.7889e-01
 -5.3729e-01]
[-5.0283e-01 -2.1629e-01 -6.8426e-02 -2.1709e-01 -4.3503e-01 -9.7496e-02
 -3.2187e-01]
[ 9.8790e-02  7.7033e-02  4.3415e-01 -1.1793e-01 -2.0530e-01  2.4601e-01
 -7.7413e-01]
[ 1.8492e-01  7.7033e-02  3.5157e-01 -2.1709e-01 -5.1305e-02  2.4601e-01
 -7.5241e-01]
[ 7.0170e-01  2.2369e-01  8.5414e-01  8.3219e-02  1.4055e-01  2.4601e-01
 -4.9415e-01]
[ 5.2944e-01 -2.1629e-01  4.3415e-01 -2.1709e-01  1.7968e-01 -9.7496e-02
 -2.7903e-01]
[ 3.5718e-01  2.2369e-01  3.5157e-01 -2.1709e-01  6.3558e-02  5.8952e-01
 -3.8644e-01]
[-1.3641e+00 -1.0962e+00 -2.0811e+00 -1.4268e+00 -1.6617e-01  7.6410e-01
 -3.8644e-01]
[-1.3641e+00 -5.0961e-01 -1.0736e+00 -1.4268e+00 -2.4317e-01  4.2059e-01
 -2.1447e-01]]
seq_x_mark: [[-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.5     0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.4565  0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.413   0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]]
seq_y_mark: [[-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3696  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.3261  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2826  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.2391  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1957  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1522  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.1087  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0652  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]
 [-0.0217  0.1667 -0.5    -0.0014]]
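Comparing the arrays above, the first label_len = 12 rows of seq_y are the last 12 rows of seq_x (both start with -3.3186e-01 ... 1.2280e+00), and seq_y_mark likewise starts at the 13th row of seq_x_mark. A NumPy sketch of the same slicing on a made-up array:

```python
import numpy as np

seq_len, label_len, pred_len = 24, 12, 24
data = np.arange(100 * 7, dtype=float).reshape(100, 7)  # synthetic stand-in for data_x / data_y

index = 0
s_begin, s_end = index, index + seq_len
r_begin, r_end = s_end - label_len, s_end + pred_len

seq_x = data[s_begin:s_end]  # encoder input: 24 steps
seq_y = data[r_begin:r_end]  # 12 known steps followed by the 24 steps to predict

print(seq_x.shape, seq_y.shape)                               # (24, 7) (36, 7)
print(np.array_equal(seq_y[:label_len], seq_x[-label_len:]))  # True
```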

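Code

import os

import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset

from utils.timefeatures import time_features  # project-local helper (Time-Series-Library layout, assumed)


class Dataset_Custom(Dataset):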
    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size [seq_len, label_len, pred_len]
        # info
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # print(cols)
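        # Original split: 70% train / 20% test / 10% validation
        # num_train = int(len(df_raw) * 0.7)
        # print("num_train:",num_train)
        # num_test = int(len(df_raw) * 0.2)

        # Split used in this post: 50% train / 20% test / 30% validation
        num_train = int(len(df_raw) * 0.5)
        print("num_train:", num_train)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        # Val/test start seq_len rows early so their first window has a full input history
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]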
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit on the training rows only, then standardize the whole dataset
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)

            # ------------------ debug output shown in this post ------------------
            print("train_data.values", train_data.values)
            print("transform_data", data)
            inverse_transform_data = self.inverse_transform(data)
            print("inverse_transform_data:", inverse_transform_data)
            # ----------------------------------------------------------------------
        else:
            data = df_data.values

        # .copy() avoids pandas' SettingWithCopyWarning when columns are added below
        df_stamp = df_raw[['date']][border1:border2].copy()
        df_stamp['date'] = pd.to_datetime(df_stamp.date)

        if self.timeenc == 0:
            # Raw integer calendar fields
            df_stamp['month'] = df_stamp.date.dt.month
            df_stamp['day'] = df_stamp.date.dt.day
            df_stamp['weekday'] = df_stamp.date.dt.weekday
            df_stamp['hour'] = df_stamp.date.dt.hour
            # pandas 2.x no longer accepts the axis positionally in drop()
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            print("df_stamp['date'].values:", df_stamp['date'].values)
            # time_features returns (n_features, n_rows); transpose to one row per timestamp
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp
        print("data_stamp:",data_stamp)
        

        print('\n'.join(['%s:%s' % item for item in self.__dict__.items()]) )


    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len
        
        print("s_begin:",s_begin)
        print("s_end:",s_end)
        print("r_begin:",r_begin)
        print("r_end:",r_end)

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        
        print("seq_x.shape:",seq_x.shape)
        print("seq_y.shape:",seq_y.shape)
        print("seq_x_mark.shape:",seq_x_mark.shape)
        print("seq_y_mark.shape:",seq_y_mark.shape)
        print("seq_x:",seq_x)
        print("seq_y:",seq_y)
        print("seq_x_mark:",seq_x_mark)
        print("seq_y_mark:",seq_y_mark)

        

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)

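Comparing the shapes (batched tensors from the DataLoader first, then the per-sample arrays from __getitem__):

for i, (batch_x, _, _, _): torch.Size([1, 24, 7])
for i, (_, batch_y, _, _): torch.Size([1, 36, 7])
for i, (_, _, batch_x_mark, _): torch.Size([1, 24, 4])
for i, (_, _, _, batch_y_mark): torch.Size([1, 36, 4])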
seq_x.shape: (24, 7)
seq_y.shape: (36, 7)
seq_x_mark.shape: (24, 4)
seq_y_mark.shape: (36, 4)
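The two groups differ only by the leading batch dimension: the DataLoader's default collation stacks the per-sample arrays along a new first axis, and 36 = label_len + pred_len = 12 + 24. The sketch below mimics that stacking with np.stack, as an assumption standing in for PyTorch's collate function; `fake_sample` is a hypothetical helper returning zero arrays with the per-sample shapes.

```python
import numpy as np

seq_len, label_len, pred_len, n_features, n_time_feats = 24, 12, 24, 7, 4

def fake_sample():
    """Zero arrays with the per-sample shapes returned by __getitem__."""
    return (np.zeros((seq_len, n_features)),
            np.zeros((label_len + pred_len, n_features)),
            np.zeros((seq_len, n_time_feats)),
            np.zeros((label_len + pred_len, n_time_feats)))

batch_size = 1  # the printed tensor shapes above come from a batch of 1
samples = [fake_sample() for _ in range(batch_size)]
# Stack each of the four fields along a new leading axis, as default collation does.
batch = [np.stack(field) for field in zip(*samples)]
print([b.shape for b in batch])  # [(1, 24, 7), (1, 36, 7), (1, 24, 4), (1, 36, 4)]
```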

How the data is used in training

First, batch_x, batch_y, batch_x_mark and batch_y_mark are passed to _predict:

outputs, batch_y = self._predict(batch_x, batch_y, batch_x_mark, batch_y_mark)

def _predict(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
    # Decoder input: zeros for the pred_len steps to be forecast...
    dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
    # ...prefixed by the label_len known steps taken from batch_y
    dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

    # Encoder-decoder forward pass
    def _run_model():
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.output_attention:
            outputs = outputs[0]  # the model also returned attention maps; keep only the predictions
        return outputs
    ......

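batch_y becomes dec_inp: its first label_len steps are kept as known context, and its last pred_len steps are replaced with zeros.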
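A NumPy sketch of the same construction, with np.zeros_like and np.concatenate standing in for torch.zeros_like and torch.cat and a synthetic batch in place of real data: the decoder sees the 12 known steps followed by 24 zero placeholders.

```python
import numpy as np

label_len, pred_len = 12, 24
batch_y = np.random.default_rng(0).normal(size=(1, label_len + pred_len, 7))  # synthetic batch

# Zero placeholders for the pred_len steps to be forecast.
dec_zeros = np.zeros_like(batch_y[:, -pred_len:, :])
# Known history (first label_len steps) followed by the placeholders.
dec_inp = np.concatenate([batch_y[:, :label_len, :], dec_zeros], axis=1)

print(dec_inp.shape)                                                 # (1, 36, 7)
print(np.allclose(dec_inp[:, :label_len], batch_y[:, :label_len]))  # True
print(np.count_nonzero(dec_inp[:, label_len:]))                      # 0
```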
If the model is the vanilla Transformer, the inputs correspond to its forward arguments as follows:

x_enc      = batch_x
x_mark_enc = batch_x_mark
x_dec      = dec_inp
x_mark_dec = batch_y_mark