PyTorch Transformers from Scratch

Published on 2023-07-10


Paper: Attention Is All You Need

Video: https://www.youtube.com/watch?v=U0s0f995w14

Architecture Diagram

  • The input on the left-hand (Encoder) side is the source sentence.
  • The input at the bottom of the Decoder on the right is the target sentence (e.g., the target-language sentence in a translation task).
  • The Linear layer just before the Softmax at the top of the Decoder is the so-called lm_head: it maps the hidden states to a probability distribution over the vocabulary (see the sketch below).
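
As a minimal sketch of what that lm_head does (the names vocab_size, embed_size, and hidden below are assumptions for illustration, not from the original code):

import torch
import torch.nn as nn

vocab_size, embed_size = 10000, 256        # assumed sizes, for illustration only
lm_head = nn.Linear(embed_size, vocab_size)

hidden = torch.randn(2, 7, embed_size)     # (batch, trg_len, embed_size) from the Decoder
logits = lm_head(hidden)                   # (batch, trg_len, vocab_size)
probs = torch.softmax(logits, dim=-1)      # probability distribution over the vocabulary
print(probs.shape)                         # torch.Size([2, 7, 10000])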

Implementation

Tricks

  • Use torch.einsum for high-dimensional tensor contractions.
  • Use masked_fill to fill a tensor according to a mask (both tricks are demonstrated in the snippet below).
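
A quick, standalone demonstration of both tricks (the shapes here are made up for illustration):

import torch

# einsum: batched matrix multiplication written as an index expression
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
c = torch.einsum('bik,bkj->bij', a, b)      # same result as torch.bmm(a, b), shape (2, 3, 5)

# masked_fill: set positions where the mask is 0 to a very negative number,
# so they get (almost) zero weight after a softmax
scores = torch.randn(2, 3, 5)
mask = torch.tensor([1, 1, 1, 0, 0])        # the last two positions are padding
scores = scores.masked_fill(mask == 0, float('-1e20'))
weights = torch.softmax(scores, dim=-1)     # masked positions receive ~0 weight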

SelfAttention

forward

  • The softmax is taken along the key_len dimension (dim=3), so for each query position the attention weights over all keys sum to 1.
  • In cross-attention, query_len need not equal key_len, but key_len and value_len always match, since every key is paired with a value; in self-attention all three lengths are equal.
  • After the multi-head computation, concatenating the heads can be done with a single reshape (see the small snippet below).
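
Both points can be checked with a tiny standalone snippet (the shapes are made up):

import torch

N, heads, query_len, key_len, head_dim = 2, 8, 5, 7, 32

# softmax along key_len: the weights for each query position sum to 1
energy = torch.randn(N, heads, query_len, key_len)
attention = torch.softmax(energy, dim=3)
print(attention.sum(dim=3))                 # all ones, shape (N, heads, query_len)

# concatenating the heads is just a reshape
out = torch.randn(N, query_len, heads, head_dim)
concat = out.reshape(N, query_len, heads * head_dim)
print(concat.shape)                         # torch.Size([2, 5, 256])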

code

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (self.heads * self.head_dim == self.embed_size), "Embed size must be divisible by heads"

        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(self.heads * self.head_dim, self.embed_size)

    def forward(self, queries, keys, values, mask=None):
        N = queries.shape[0]
        query_len, key_len, value_len = queries.shape[1], keys.shape[1], values.shape[1]

        # split embeddings into heads pieces (reshape is not in-place, so assign the result)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum('nqhd,nkhd->nhqk', [queries, keys])

        # fill according to mask
        if mask is not None:
            energy = energy.masked_fill(mask==0, float("-1e20"))

        # scale, then softmax along the key dimension
        # (note: the paper scales by sqrt(d_k) = sqrt(head_dim); this code uses sqrt(embed_size) instead)
        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # out shape: (N, query_len, heads, head_dim)
        # This is the raw attention output, which should have the same shape as queries.
        # key_len and value_len are contracted with the same index, since they are always equal.
        out = torch.einsum('nhqk,nkhd->nqhd', [attention, values])

        # concatenate out, mapping by fc_out and return
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        out = self.fc_out(out)
        return out
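
A quick shape check of the module above (the sizes are assumptions for this sketch):

import torch

attention = SelfAttention(embed_size=256, heads=8)
x = torch.randn(2, 10, 256)     # (N, seq_len, embed_size)
out = attention(x, x, x)        # self-attention: queries, keys and values are all x
print(out.shape)                # torch.Size([2, 10, 256])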

TransformerBlock

__init__

  • The two normalization steps correspond to two separate LayerNorm modules.
  • The feed-forward network first expands the input (scaling the dimension up by forward_expansion), applies the activation, and then projects back down (e.g., with embed_size=256 and forward_expansion=4 it maps 256 → 1024 → 256). The expanded dimension lets the feed-forward layer apply a richer non-linear transformation and learn more expressive features, which helps the model capture the relationships and semantics in the input sequence.
  • Finally, a dropout layer is also needed.

forward

Just carry out the computations in order: attention, residual + norm + dropout, feed-forward, residual + norm + dropout.

code

class TransformerBlock(nn.Module):
    def __init__(self, 
                 embed_size, 
                 heads, 
                 forward_expansion, 
                 dropout, 
                 *args, 
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

    def forward(self, query, key, value, mask=None):
        attention = self.attention(query, key, value, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
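
And a similarly brief shape check for the block (the sizes are assumptions):

import torch

block = TransformerBlock(embed_size=256, heads=8, forward_expansion=4, dropout=0.1)
x = torch.randn(2, 10, 256)
out = block(x, x, x)            # in the Encoder, query, key and value are the same tensor
print(out.shape)                # torch.Size([2, 10, 256])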

Encoder

__init__

Initialize word_embedding, position_embedding, and layers. They are used, respectively, to turn the input tokens into word vectors, to turn position indices into position vectors, and to stack several TransformerBlocks. Note that positions are embedded with a learned nn.Embedding here, rather than the fixed sinusoidal encoding of the paper.

forward

Embed the input tokens and the position indices, add the two embeddings, apply dropout, feed the result through the layers, and return the output.

code

class Encoder(nn.Module):
    def __init__(self, 
                 src_vocab_size, 
                 embed_size, 
                 heads, 
                 forward_expansion, 
                 dropout, 
                 num_layers, 
                 max_length, 
                 device,
                 *args, 
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.src_vocab_size = src_vocab_size
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.num_layers = num_layers
        self.max_length = max_length
        self.device = device
        self.dropout = nn.Dropout(dropout)
        self.word_embedding = nn.Embedding(self.src_vocab_size, self.embed_size)
        self.position_embedding = nn.Embedding(self.max_length, self.embed_size)
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, heads, forward_expansion, dropout)
             for _ in range(self.num_layers)]
        )

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, out, out, mask) 
        return out
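
A short usage sketch, mainly to show what the padding mask looks like (pad index 0 and all sizes are assumptions):

import torch

enc = Encoder(src_vocab_size=100, embed_size=256, heads=8, forward_expansion=4,
              dropout=0.1, num_layers=2, max_length=50, device='cpu')
src = torch.randint(1, 100, (2, 10))              # (N, src_len) token ids; 0 reserved for padding
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)   # (N, 1, 1, src_len), broadcasts over heads and queries
out = enc(src, src_mask)
print(out.shape)                                  # torch.Size([2, 10, 256])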

DecoderBlock

Structurally, this is just a TransformerBlock with one extra SelfAttention and LayerNorm in front, so we can simply build on top of it.

There are two masks here, one for the target and one for the source.

The key and value come from the Encoder's output, while the query is computed from the input x; these three, together with src_mask, are fed into the TransformerBlock to produce the output. Note that the key and value are actually identical: both are the same Encoder output.

code

class DecoderBlock(nn.Module):
    def __init__(self, 
                 embed_size, 
                 heads, 
                 forward_expansion, 
                 dropout, 
                 *args, 
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.attention = SelfAttention(embed_size, heads)
        self.transformer_block = TransformerBlock(embed_size, heads, forward_expansion, dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, x, key, value, trg_mask, src_mask=None):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(query, key, value, src_mask)
        return out

Decoder

Assemble it much like the Encoder; note that the Encoder's output is also required as an input.

code

class Decoder(nn.Module):
    def __init__(self, 
                 embed_size, 
                 heads, 
                 forward_expansion, 
                 trg_vocab_size, 
                 dropout, 
                 max_length, 
                 num_layers, 
                 device, 
                 *args, 
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.trg_vocab_size = trg_vocab_size
        self.max_length = max_length
        self.num_layers = num_layers
        self.device = device
        self.word_embedding = nn.Embedding(self.trg_vocab_size, self.embed_size)
        self.position_embedding = nn.Embedding(self.max_length, self.embed_size)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [DecoderBlock(self.embed_size, self.heads, self.forward_expansion, dropout)
             for _ in range(self.num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, self.trg_vocab_size)

    def forward(self, x, enc_out, trg_mask, src_mask=None):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, enc_out, enc_out, trg_mask, src_mask)
        out = self.fc_out(out)
        return out

Transformer

code

class Transformer(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            trg_vocab_size,
            src_pad_idx,
            trg_pad_idx,
            embed_size=256,
            heads=8,
            forward_expansion=4,
            num_layers=6,
            dropout=0,
            device='cuda',
            max_length=128
    ):
        super().__init__()
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.num_layers = num_layers
        self.device = device
        self.max_length = max_length
        self.encoder = Encoder(self.src_vocab_size, self.embed_size, self.heads, self.forward_expansion, dropout, self.num_layers, self.max_length, self.device)
        self.decoder = Decoder(self.embed_size, self.heads, self.forward_expansion, self.trg_vocab_size, dropout, self.max_length, self.num_layers, self.device)

    def make_src_mask(self, src):
        # src_mask shape: (N, 1, 1, src_len); 0 marks padding positions
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        # trg_mask shape: (N, 1, trg_len, trg_len)
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, src_mask)
        dec_out = self.decoder(trg, enc_out, trg_mask, src_mask)
        return dec_out.to(self.device)
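
Finally, a small end-to-end smoke test (a sketch: the token ids are made up, 0 is used as the padding index, and the vocabulary sizes are assumptions):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# toy source/target batches; 0 is the padding index
src = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0],
                    [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0],
                    [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

model = Transformer(src_vocab_size=10, trg_vocab_size=10,
                    src_pad_idx=0, trg_pad_idx=0, device=device).to(device)

# feed the target shifted by one position (teacher forcing)
out = model(src, trg[:, :-1])
print(out.shape)                # torch.Size([2, 7, 10]) = (N, trg_len - 1, trg_vocab_size)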