Handling Variable-Length Sequences in PyTorch

When modeling with recurrent neural networks, you frequently need to handle sequences of different lengths, e.g. natural-language sentences. This raises two problems:

1. The RNN must be unrolled to the length of the longest sequence, but shorter sequences provide fewer inputs than that length.

2. After the forward pass, the result for each sequence should be the hidden state emitted at that sequence's last valid timestep, not the hidden state of the final cell of the whole unrolled RNN.

Figure 1

The first problem is solved by padding: append 0 (PAD) to the end of every sequence shorter than the maximum length.

For the second problem, PyTorch's nn.utils.rnn module provides two functions, pack_padded_sequence and pad_packed_sequence. Read the names as "pack a padded_sequence" and "pad a packed_sequence".

Before packing, sort the sequences by length in descending order, then build a tensor holding each sequence's length:

import random
import torch
import torch.nn as nn

# build a toy training set: five sequences of lengths 1..5
inputs = [list(range(1, i+1)) for i in range(1, 6)]
random.shuffle(inputs)
print("original inputs:\n", inputs)

# sort by length, descending
inputs.sort(key=lambda x: len(x), reverse=True)
print("sorted inputs:\n", inputs)

length = [len(sublist) for sublist in inputs]
print("sequence lengths:\n", length)

# pad every sequence to the maximum length with 0 (PAD)
maxLen = max(length)
inputs = [sublist + [0]*(maxLen-len(sublist)) for sublist in inputs]
print("padded inputs:\n", inputs)

output:
original inputs:
[[1, 2], [1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
sorted inputs:
[[1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3], [1, 2], [1]]
sequence lengths:
[5, 4, 3, 2, 1]
padded inputs:
[[1, 2, 3, 4, 5], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0], [1, 2, 0, 0, 0], [1, 0, 0, 0, 0]]

Now call pack_padded_sequence. Note that it assumes time-major input by default (batch_first=False); since our padded matrix is laid out as (batch, seq_len), we pass batch_first=True:

inputs = torch.FloatTensor(inputs)
length = torch.LongTensor(length)
# our inputs are (batch, seq_len), so batch_first=True
pack_inputs = nn.utils.rnn.pack_padded_sequence(inputs, length, batch_first=True)
print("packed inputs:\n", pack_inputs)

output:
packed inputs:
PackedSequence(data=tensor([1., 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 4., 4., 5.]), batch_sizes=tensor([5, 4, 3, 2, 1]))

As you can see, pack_padded_sequence produces a PackedSequence object. Its data field stores the sequences in timestep-major order: first the first element of every sequence, then the second element of every still-active sequence, and so on; batch_sizes records how many sequences are still active at each timestep. PyTorch performs the steps illustrated in the figure below.

Figure 2
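To make this layout concrete, here is a minimal sketch (reusing the pack_inputs object from above) that walks batch_sizes and prints the slice of data belonging to each timestep:

# walk batch_sizes to expose the timestep-major layout of the packed data
data, batch_sizes = pack_inputs.data, pack_inputs.batch_sizes

offset = 0
for t, bs in enumerate(batch_sizes.tolist()):
    # the slice holding timestep t of every still-active sequence
    print("t = %d:" % t, data[offset:offset + bs].tolist())
    offset += bs

# t = 0: [1.0, 1.0, 1.0, 1.0, 1.0]
# t = 1: [2.0, 2.0, 2.0, 2.0]
# t = 2: [3.0, 3.0, 3.0]
# t = 3: [4.0, 4.0]
# t = 4: [5.0]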

Calling pad_packed_sequence:

pad_inputs, pad_len = nn.utils.rnn.pad_packed_sequence(pack_inputs, batch_first=True)
print("padded inputs:\n", pad_inputs)
print("recovered lengths:\n", pad_len)

outputs:
padded inputs:
tensor([[1., 2., 3., 4., 5.],
        [1., 2., 3., 4., 0.],
        [1., 2., 3., 0., 0.],
        [1., 2., 0., 0., 0.],
        [1., 0., 0., 0., 0.]])
recovered lengths:
tensor([5, 4, 3, 2, 1])

The function returns two values: the padded tensor and a tensor of the sequence lengths.

For every RNN variant in PyTorch, if the input is a PackedSequence, the output is also a PackedSequence (note: this applies only to output; hidden and context are still plain Tensors).

# preprocess the inputs: add a feature dimension
inputs = inputs.unsqueeze(2)  # (batch, seq_len, 1)
print(inputs.shape)
inputs = nn.utils.rnn.pack_padded_sequence(inputs, length, batch_first=True)
model = nn.LSTM(1, 15, 3, batch_first=True, bidirectional=True)
outputs, (hidden, context) = model(inputs)
print("output's type:\n", type(outputs))
print("hidden's type:\n", type(hidden))
print("context's type:\n", type(context))

outputs:
torch.Size([5, 5, 1])
output's type:
<class 'torch.nn.utils.rnn.PackedSequence'>
hidden's type:
<class 'torch.Tensor'>
context's type:
<class 'torch.Tensor'>

To process outputs further, call pad_packed_sequence on it:

outputs, length = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

print("output's type:\n", type(outputs))
print("output's dimension:\n", outputs.shape)
print("length tensor:\n", length)

outputs:
output's type:
<class 'torch.Tensor'>
output's dimension:
torch.Size([5, 5, 30])
length tensor:
tensor([5, 4, 3, 2, 1])

In this example, batch_size = 5, hidden_size = 15, and seq_len = 5; with batch_first=True and a bidirectional network, outputs therefore has shape [5, 5, 2 × 15 = 30].
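As a quick sanity check (a short sketch continuing the example), the accompanying hidden and context tensors keep the usual (num_layers × num_directions, batch, hidden_size) layout, which batch_first does not affect:

# hidden/context layout: (num_layers * num_directions, batch, hidden_size)
print(hidden.shape)   # torch.Size([6, 5, 15])  <- 3 layers x 2 directions
print(context.shape)  # torch.Size([6, 5, 15])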

One more detail: each sequence only has as many valid outputs as its own length, yet the padded outputs span the length of the longest sequence. What do the extra positions of the shorter sequences contain? Print outputs to see:

print(outputs)

outputs:
tensor([[[-4.0682e-02, 8.3953e-02, -3.9291e-02, 3.0797e-02, 1.4554e-02,
1.1384e-01, 4.2089e-02, 2.5681e-02, 4.7258e-02, 4.8985e-02,
-3.6537e-03, -3.3367e-02, 2.4790e-02, -1.4111e-02, -1.7248e-02,
-6.7965e-02, 1.0308e-01, -1.2599e-01, -9.9059e-02, 1.7291e-01,
-1.8176e-01, 1.0104e-01, 1.0875e-01, -4.0975e-03, -3.4452e-02,
9.5266e-02, -6.9321e-02, -3.3464e-03, 8.6780e-03, -1.0240e-01],
[-4.2807e-02, 8.0022e-02, -3.9385e-02, 3.3965e-02, 1.5756e-02,
1.1501e-01, 4.0348e-02, 2.6273e-02, 4.7790e-02, 5.0116e-02,
3.0397e-03, -3.4664e-02, 2.6208e-02, -1.2610e-02, -1.9977e-02,
-6.5606e-02, 1.1538e-01, -1.2242e-01, -1.0063e-01, 1.6027e-01,
-1.6615e-01, 9.5598e-02, 1.0820e-01, -2.6833e-04, -3.2401e-02,
9.4688e-02, -7.4737e-02, -8.7567e-03, -1.4937e-03, -9.9632e-02],
[-4.5313e-02, 7.8306e-02, -3.7688e-02, 3.6819e-02, 1.6507e-02,
1.1540e-01, 4.0500e-02, 2.4949e-02, 4.8806e-02, 4.9541e-02,
8.8827e-03, -3.4783e-02, 2.6201e-02, -1.0008e-02, -2.2865e-02,
-6.0985e-02, 1.1900e-01, -1.1312e-01, -9.9706e-02, 1.4205e-01,
-1.3885e-01, 8.6394e-02, 1.0617e-01, -2.5593e-04, -2.3208e-02,
8.9718e-02, -7.3640e-02, -1.3402e-02, -7.4062e-03, -9.3399e-02],
[-4.7680e-02, 7.9925e-02, -3.3929e-02, 3.8575e-02, 1.6410e-02,
1.1501e-01, 4.2421e-02, 2.0348e-02, 5.0167e-02, 4.8087e-02,
1.4561e-02, -3.3355e-02, 2.4420e-02, -6.1677e-03, -2.5104e-02,
-5.2512e-02, 1.0933e-01, -9.2953e-02, -9.2083e-02, 1.1344e-01,
-1.0036e-01, 7.0717e-02, 9.7617e-02, -4.6024e-03, -9.8933e-03,
7.6389e-02, -6.2963e-02, -1.5895e-02, -7.7886e-03, -8.1531e-02],
[-5.0099e-02, 8.6489e-02, -2.8123e-02, 3.8329e-02, 1.5578e-02,
1.1394e-01, 4.5894e-02, 1.0719e-02, 5.1936e-02, 4.6758e-02,
2.1022e-02, -2.9532e-02, 2.0833e-02, -8.2043e-04, -2.5985e-02,
-3.5800e-02, 7.6206e-02, -5.5761e-02, -6.7257e-02, 6.7798e-02,
-5.3698e-02, 4.4487e-02, 7.0936e-02, -9.9438e-03, 1.3279e-03,
4.8148e-02, -3.9513e-02, -1.3307e-02, -2.4901e-03, -5.6968e-02]],

[[-5.9009e-02, 1.1090e-01, -6.0885e-02, 4.2813e-02, 2.8173e-02,
1.5819e-01, 7.3094e-02, 5.3410e-02, 7.4454e-02, 7.8682e-02,
-1.0169e-02, -5.0618e-02, 3.0245e-02, -1.5371e-02, -2.9641e-02,
-6.8304e-02, 1.0834e-01, -1.1980e-01, -1.0181e-01, 1.6701e-01,
-1.6799e-01, 9.6521e-02, 1.0703e-01, -9.8389e-03, -2.5798e-02,
9.1676e-02, -6.7476e-02, -8.5891e-03, 3.1723e-03, -1.0151e-01],
[-6.3383e-02, 1.0464e-01, -5.9683e-02, 4.9038e-02, 2.9302e-02,
1.6012e-01, 7.0652e-02, 5.2944e-02, 7.6386e-02, 8.0110e-02,
2.7956e-03, -5.2098e-02, 3.1982e-02, -1.1622e-02, -3.4363e-02,
-6.3715e-02, 1.1809e-01, -1.1284e-01, -1.0229e-01, 1.4737e-01,
-1.4279e-01, 8.5988e-02, 1.0390e-01, -8.6067e-03, -2.1914e-02,
8.8866e-02, -6.9284e-02, -1.3097e-02, -7.6538e-03, -9.4410e-02],
[-6.8404e-02, 1.0299e-01, -5.5642e-02, 5.3935e-02, 2.9866e-02,
1.6053e-01, 7.1565e-02, 4.7939e-02, 7.8676e-02, 7.8942e-02,
1.4926e-02, -5.1489e-02, 3.1048e-02, -6.0752e-03, -3.8717e-02,
-5.4326e-02, 1.1279e-01, -9.5235e-02, -9.5720e-02, 1.1671e-01,
-1.0354e-01, 6.9420e-02, 9.5840e-02, -1.1347e-02, -1.1086e-02,
7.6809e-02, -6.2138e-02, -1.6182e-02, -1.1471e-02, -8.1150e-02],
[-7.3133e-02, 1.0863e-01, -4.8284e-02, 5.5917e-02, 2.9400e-02,
1.5938e-01, 7.5490e-02, 3.5192e-02, 8.1008e-02, 7.7179e-02,
2.6736e-02, -4.7263e-02, 2.7241e-02, 1.1862e-03, -4.0976e-02,
-3.6420e-02, 8.1268e-02, -5.9056e-02, -7.0855e-02, 6.8678e-02,
-5.3927e-02, 4.2728e-02, 7.0533e-02, -1.4319e-02, 3.2542e-04,
4.9083e-02, -4.1215e-02, -1.4146e-02, -6.8797e-03, -5.5682e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

[[-6.8628e-02, 1.2067e-01, -6.9671e-02, 4.8275e-02, 3.8466e-02,
1.7669e-01, 9.4873e-02, 7.2925e-02, 9.2304e-02, 9.4618e-02,
-1.2577e-02, -5.9884e-02, 2.7681e-02, -1.2495e-02, -3.7470e-02,
-6.5208e-02, 1.0843e-01, -1.0809e-01, -1.0186e-01, 1.5302e-01,
-1.4540e-01, 8.8472e-02, 1.0491e-01, -1.5388e-02, -1.4619e-02,
8.4476e-02, -6.1932e-02, -1.3132e-02, -1.8400e-04, -9.6829e-02],
[-7.5361e-02, 1.1422e-01, -6.6453e-02, 5.6688e-02, 3.8909e-02,
1.7890e-01, 9.2584e-02, 6.9072e-02, 9.5968e-02, 9.5608e-02,
6.4508e-03, -6.0464e-02, 2.8626e-02, -5.9382e-03, -4.3329e-02,
-5.5938e-02, 1.0904e-01, -9.3068e-02, -9.6616e-02, 1.2118e-01,
-1.0824e-01, 7.0649e-02, 9.5562e-02, -1.7137e-02, -9.0465e-03,
7.4589e-02, -5.7249e-02, -1.5995e-02, -8.6256e-03, -8.2399e-02],
[-8.2987e-02, 1.1563e-01, -5.9453e-02, 6.1940e-02, 3.8952e-02,
1.7882e-01, 9.4832e-02, 5.6927e-02, 9.9607e-02, 9.4224e-02,
2.4712e-02, -5.7339e-02, 2.6087e-02, 2.5044e-03, -4.7602e-02,
-3.7186e-02, 8.1333e-02, -5.9543e-02, -7.2458e-02, 7.0927e-02,
-5.7131e-02, 4.2974e-02, 7.0481e-02, -1.8356e-02, 4.7081e-05,
4.8462e-02, -3.9545e-02, -1.4449e-02, -7.0511e-03, -5.5383e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

[[-7.5215e-02, 1.2570e-01, -7.0803e-02, 5.1324e-02, 4.5722e-02,
1.8440e-01, 1.1057e-01, 8.2077e-02, 1.0470e-01, 1.0276e-01,
-9.1051e-03, -6.4096e-02, 2.2053e-02, -6.9874e-03, -4.2566e-02,
-5.7048e-02, 1.0045e-01, -8.6925e-02, -9.5618e-02, 1.2530e-01,
-1.1105e-01, 7.3675e-02, 9.7405e-02, -2.1009e-02, -2.4056e-03,
6.9669e-02, -5.1597e-02, -1.6832e-02, -1.2028e-03, -8.5562e-02],
[-8.4856e-02, 1.2209e-01, -6.5173e-02, 6.0102e-02, 4.5486e-02,
1.8622e-01, 1.0906e-01, 7.1573e-02, 1.1011e-01, 1.0305e-01,
1.5899e-02, -6.1909e-02, 2.1459e-02, 2.8874e-03, -4.8808e-02,
-3.8185e-02, 7.8297e-02, -5.6673e-02, -7.2678e-02, 7.3505e-02,
-6.0715e-02, 4.4512e-02, 7.1230e-02, -2.1180e-02, 1.8828e-03,
4.5976e-02, -3.6372e-02, -1.4772e-02, -4.1149e-03, -5.6574e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

[[-8.2753e-02, 1.3126e-01, -6.8009e-02, 5.2415e-02, 5.1062e-02,
1.8679e-01, 1.2218e-01, 8.0173e-02, 1.1368e-01, 1.0678e-01,
1.2051e-03, -6.3800e-02, 1.5488e-02, 2.2219e-03, -4.7019e-02,
-3.9210e-02, 7.3067e-02, -5.1220e-02, -7.1830e-02, 7.5660e-02,
-6.2923e-02, 4.6912e-02, 7.2609e-02, -2.2428e-02, 6.0529e-03,
4.1805e-02, -3.3265e-02, -1.5978e-02, 8.6623e-04, -5.9381e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]],
grad_fn=<CopySlices>)

As you can see, the output at every PAD position is a zero vector. So for sequences of different lengths, you only need to slice each sequence's output up to its own length.
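For example, one common pattern for grabbing each sequence's output at its last valid timestep is a gather along the time dimension (a minimal sketch continuing the example above; note that for the backward half of a bidirectional RNN the final state actually sits at t = 0, so hidden is usually the better source for it):

# pick outputs[b, length[b] - 1, :] for every sequence b in the batch
idx = (length - 1).view(-1, 1, 1).expand(-1, 1, outputs.size(2))
last_outputs = outputs.gather(1, idx).squeeze(1)
print(last_outputs.shape)  # torch.Size([5, 30])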

Notes:

1. If the sequences are not sorted in descending order of length, pack_padded_sequence raises an error:

RuntimeError: 'lengths' array has to be sorted in decreasing order
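Since PyTorch 1.1 you can skip the manual sort by passing enforce_sorted=False, which sorts the batch internally and restores the original order when unpacking:

# PyTorch >= 1.1: no manual sorting required
pack_inputs = nn.utils.rnn.pack_padded_sequence(
    inputs, length, batch_first=True, enforce_sorted=False)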

2. PAD does not have to be 0. The function really just uses the lengths tensor to discard whatever lies beyond each sequence's length. In NLP you can therefore embed first and pack afterwards: the embedding maps the PAD token to a vector, which does not interfere with packing, as the sketch below shows.
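A minimal NLP-style sketch (the vocabulary and embedding sizes here are made up for illustration):

# hypothetical sizes, just for illustration
vocab_size, embed_dim = 100, 8
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)  # token id 0 is PAD

# the padded id matrix from the first example
token_ids = torch.LongTensor([[1, 2, 3, 4, 5],
                              [1, 2, 3, 4, 0],
                              [1, 2, 3, 0, 0],
                              [1, 2, 0, 0, 0],
                              [1, 0, 0, 0, 0]])
embedded = embedding(token_ids)  # (batch, seq_len, embed_dim); PAD rows become vectors
packed = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True)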

Author: 地瓜
Link: https://www.l-zhe.com/2019/08/09/Pytorch中的变长序列处理/
License: Unless otherwise stated, all articles on this blog are licensed under CC BY-NC-SA 4.0. Please credit 人参地里的地瓜 when reposting.