为了提高效率,模型在处理句子(embeddings)时并不是一句一句进行的,而是以 batch 为单位批量处理。但一个 batch 内的句子大概率是不等长的,为了方便起见,我们会向其中填充 pad token,使得其长度一致。这也是 cs224n Assignment 4 中 utils.py
内的 pad_sents
函数所做的事情。
此时,在每个 time step 内,我们可以读取一列单词来分别进行处理。但是同时,我们不希望填充进来的 pad token 对模型产生影响,所以就有了 pack_padded_sequence 。
pack_padded_sequence
先看一下这个函数的使用:
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1,2,0],
[3,0,0],
[4,5,6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
假设我们现在处理一个 batch seq
,每一行是一个句子,其中的 0 是 pad token 。
我们可以在 time step = 1
时,读入第一列的三个数进行处理,以此类推,经过 3 个 time step 处理完毕。
但这样一来,我们把 pad token 也一起读入了。为了避免这种情况,使用 pack_padded_sequence
函数,得到 packed
。
此时 packed
类型为 PackedSequence
,其中的 data
是一个一维张量。不难看出,其顺序正是每个 time step 读入的数据顺序(按句子长度排序,按列读入),但是去除了 0。
将 seq
按句子长度排序得到:
>>> seq = torch.tensor([[4,5,6],
[1,2,0],
[3,0,0]
])
记原始行下标为 [0, 1, 2]
,那么排序后的行下标是 [2, 0, 1]
,也就是 sorted_indices
的值。
按 3 个 time step 顺序模拟一遍:
- 读入第一列
[4, 1, 3]
- 读入第二列
[5, 2]
- 读入第三列
[6]
三次数字拼起来就是 data
值。而 sorted_indices
说明了每次读入值在原始 batch 中的顺序。
所以,通过 pack_padded_sequence
返回的 PackedSequence
,我们可以在不受 pad token 干扰的情况下处理 batch 内的数据。
需要注意的是,我们通常都会先将 seq
按长度从大到小排序后再使用 pack_padded_sequence
,否则会报错。
pad_packed_sequence
显然,对 PackedSequence
进行处理产生的输出,也会是与其一样的格式,所以我们还需要一个函数来将其还原成原来的样子,这就是 pad_packed_sequence
的功能。
下面的例子足够说明其作用:
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
[3, 0, 0],
[4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])
Comments NOTHING