pack_padded_sequence 与 pad_packed_sequence

发布于 2023-02-06  133 次阅读


为了提高效率,模型在处理句子(embeddings)时并不是一句一句进行的,而是以 batch 为单位批量处理。但一个 batch 内的句子大概率是不等长的,为了方便起见,我们会向其中填充 pad token,使得其长度一致。这也是 cs224n Assignment 4 中 utils.py 内的 pad_sents 函数所做的事情。

此时,在每个 time step 内,我们可以读取一列单词来分别进行处理。但是同时,我们不希望填充进来的 pad token 对模型产生影响,所以就有了 pack_padded_sequence 。

pack_padded_sequence

先看一下这个函数的使用:

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1,2,0],
                        [3,0,0],
                        [4,5,6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))

假设我们现在处理一个 batch seq ,每一行是一个句子,其中的 0 是 pad token 。

我们可以在 time step = 1 时,读入第一列的三个数进行处理,以此类推,经过 3 个 time step 处理完毕。

但这样一来,我们把 pad token 也一起读入了。为了避免这种情况,使用 pack_padded_sequence 函数,得到 packed

此时 packed 类型为 PackedSequence ,其中的 data 是一个一维张量。不难看出,其顺序正是每个 time step 读入的数据顺序(按句子长度排序,按列读入),但是去除了 0。

seq 按句子长度排序得到:

>>> seq = torch.tensor([[4,5,6],
                        [1,2,0],
                        [3,0,0]
                        ])

记原始行下标为 [0, 1, 2] ,那么排序后的行下标是 [2, 0, 1] ,也就是 sorted_indices 的值。

按 3 个 time step 顺序模拟一遍:

  1. 读入第一列 [4, 1, 3]
  2. 读入第二列 [5, 2]
  3. 读入第三列 [6]

三次数字拼起来就是 data 值。而 sorted_indices 说明了每次读入值在原始 batch 中的顺序。

所以,通过 pack_padded_sequence 返回的 PackedSequence ,我们可以在不受 pad token 干扰的情况下处理 batch 内的数据。

需要注意的是,我们通常都会先将 seq 按长度从大到小排序后再使用 pack_padded_sequence ,否则会报错。

pad_packed_sequence

显然,对 PackedSequence 进行处理产生的输出,也会是与其一样的格式,所以我们还需要一个函数来将其还原成原来的样子,这就是 pad_packed_sequence 的功能。

下面的例子足够说明其作用:

>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])