Paper: Attention is all you need
Architecture diagram
- The input to the Encoder on the left is the source sentence.
- The input at the bottom of the Decoder on the right is the target sentence (e.g., the target-language sentence in translation).
- The Linear layer just before the Softmax at the top of the Decoder on the right is the so-called lm_head; it maps the hidden states to a probability distribution over the vocabulary.
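As a minimal sketch of that final projection (the sizes and variable names here are only illustrative, not taken from the paper):
import torch
import torch.nn as nn

embed_size, vocab_size = 512, 32000             # illustrative sizes
lm_head = nn.Linear(embed_size, vocab_size)     # the Linear just before the Softmax
hidden = torch.randn(2, 10, embed_size)         # (batch, seq_len, embed_size)
probs = torch.softmax(lm_head(hidden), dim=-1)  # distribution over the vocabulary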
Implementation
Tips
- Use torch.einsum for high-dimensional tensor contractions.
- Use masked_fill to fill positions of a tensor according to a mask.
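A quick, self-contained demo of the two calls (shapes chosen only for illustration):
import torch

# einsum: a batched matrix multiplication written by naming the axes
a = torch.randn(2, 3, 4)                   # (batch, i, k)
b = torch.randn(2, 4, 5)                   # (batch, k, j)
c = torch.einsum('bik,bkj->bij', a, b)     # (batch, i, j), same result as torch.bmm(a, b)

# masked_fill: write a value wherever the boolean mask is True
scores = torch.randn(3, 3)
allowed = torch.tril(torch.ones(3, 3))     # lower-triangular "visible" positions
scores = scores.masked_fill(allowed == 0, float('-1e20'))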
SelfAttention
forward
- The softmax is taken along the key_len dimension, so for each query the attention weights over all keys sum to 1 (each row of the query_len × key_len attention matrix sums to 1).
- In cross-attention, query_len and key_len need not be equal, while key_len and value_len always match because keys and values come from the same sequence; in self-attention all three lengths are equal.
- After the per-head computation, concatenating the heads can be done directly with reshape (see the short demo below).
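A tiny sanity check of those two points (shapes are arbitrary):
import torch

attention = torch.softmax(torch.randn(4, 6), dim=-1)   # (query_len, key_len)
print(attention.sum(dim=-1))                            # each of the 4 rows sums to 1

x = torch.randn(2, 5, 8, 16)                            # (N, seq_len, heads, head_dim)
concat = x.reshape(2, 5, 8 * 16)                        # same as concatenating the 8 heads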
code
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (self.heads * self.head_dim == self.embed_size), "Embed size must be divisible by heads"
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(self.heads * self.head_dim, self.embed_size)

    def forward(self, queries, keys, values, mask=None):
        N = queries.shape[0]
        query_len, key_len, value_len = queries.shape[1], keys.shape[1], values.shape[1]
        # split embeddings into `heads` pieces (reshape returns a new tensor, so reassign)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape:    (N, key_len, heads, head_dim)
        # energy shape:  (N, heads, query_len, key_len)
        energy = torch.einsum('nqhd,nkhd->nhqk', [queries, keys])
        # fill the positions blocked by the mask with a large negative value
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        # note: the paper scales by sqrt(head_dim); this implementation scales by sqrt(embed_size)
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)
        # values shape:    (N, key_len, heads, head_dim)  (value_len == key_len)
        # out shape:       (N, query_len, heads, head_dim), the same shape as queries
        out = torch.einsum('nhqk,nkhd->nqhd', [attention, values])
        # concatenate the heads, map through fc_out and return
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        out = self.fc_out(out)
        return out
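A quick usage check with made-up sizes, illustrating that query_len and key_len may differ while the output keeps the query's length:
attn = SelfAttention(embed_size=256, heads=8)
q = torch.randn(2, 7, 256)     # (N, query_len, embed_size)
kv = torch.randn(2, 10, 256)   # (N, key_len, embed_size); key_len != query_len is fine
out = attn(q, kv, kv)          # out.shape == (2, 7, 256)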
TransformerBlock
__init__
- The two norm operations correspond to two separate LayerNorm layers.
- The feed-forward network first expands the input (up-projects by a factor of forward_expansion), applies the activation, and then projects back down. The wider hidden dimension lets the feed-forward layer apply a richer non-linear transformation and learn more expressive features, which helps the model capture the relationships and semantics in the input sequence.
- A dropout layer is also needed at the end.
forward
Just compute the steps in order.
code
class TransformerBlock(nn.Module):
    def __init__(self,
                 embed_size,
                 heads,
                 forward_expansion,
                 dropout,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        # position-wise feed-forward: expand by forward_expansion, activate, project back
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

    def forward(self, query, key, value, mask=None):
        attention = self.attention(query, key, value, mask)
        # residual connection + layer norm, then dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
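A shape check with illustrative sizes; the block maps (N, seq_len, embed_size) back to the same shape:
block = TransformerBlock(embed_size=256, heads=8, forward_expansion=4, dropout=0.1)
x = torch.randn(2, 10, 256)
y = block(x, x, x)             # self-attention usage: query = key = value
assert y.shape == x.shape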
Encoder
__init__
Initialize word_embedding, position_embedding, and layers. They are used, respectively, to turn the input tokens into word vectors, to turn position indices into position vectors, and to stack several TransformerBlock modules.
forward
Embed the input tokens and their positions, add the two, apply dropout, then feed the result through layers and return the output.
code
class Encoder(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 heads,
                 forward_expansion,
                 dropout,
                 num_layers,
                 max_length,
                 device,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.src_vocab_size = src_vocab_size
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.num_layers = num_layers
        self.max_length = max_length
        self.device = device
        self.dropout = nn.Dropout(dropout)
        self.word_embedding = nn.Embedding(self.src_vocab_size, self.embed_size)
        # learned positional embeddings (the paper uses fixed sinusoidal encodings)
        self.position_embedding = nn.Embedding(self.max_length, self.embed_size)
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, heads, forward_expansion, dropout)
             for _ in range(self.num_layers)]
        )

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out
DecoderBlock
Structurally it only adds one more SelfAttention and one norm on top of TransformerBlock, so it can be built by reusing that block.
There are two masks here, one for the target and one for the source.
The Encoder supplies the key and value, while the query comes from the input x; these three, together with src_mask, are fed into the TransformerBlock to produce the output. Note that the key and value are actually identical: both are the same Encoder output.
code
class DecoderBlock(nn.Module):
    def __init__(self,
                 embed_size,
                 heads,
                 forward_expansion,
                 dropout,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.attention = SelfAttention(embed_size, heads)
        self.transformer_block = TransformerBlock(embed_size, heads, forward_expansion, dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, x, key, value, trg_mask, src_mask=None):
        # masked self-attention over the target sequence
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        # cross-attention: key and value come from the encoder output
        out = self.transformer_block(query, key, value, src_mask)
        return out
Decoder
Assemble it much like the Encoder; note that the Encoder's output is needed as an input.
code
class Decoder(nn.Module):
    def __init__(self,
                 embed_size,
                 heads,
                 forward_expansion,
                 trg_vocab_size,
                 dropout,
                 max_length,
                 num_layers,
                 device,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.trg_vocab_size = trg_vocab_size
        self.max_length = max_length
        self.num_layers = num_layers
        self.device = device
        self.word_embedding = nn.Embedding(self.trg_vocab_size, self.embed_size)
        self.position_embedding = nn.Embedding(self.max_length, self.embed_size)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [DecoderBlock(self.embed_size, self.heads, self.forward_expansion, dropout)
             for _ in range(self.num_layers)]
        )
        # the lm_head: project hidden states to logits over the target vocabulary
        self.fc_out = nn.Linear(embed_size, self.trg_vocab_size)

    def forward(self, x, enc_out, trg_mask, src_mask=None):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, enc_out, enc_out, trg_mask, src_mask)
        out = self.fc_out(out)
        return out
Transformer
code
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        heads=8,
        forward_expansion=4,
        num_layers=6,
        dropout=0,
        device='cuda',
        max_length=128
    ):
        super().__init__()
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.embed_size = embed_size
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.num_layers = num_layers
        self.device = device
        self.max_length = max_length
        self.encoder = Encoder(self.src_vocab_size, self.embed_size, self.heads, self.forward_expansion, dropout, self.num_layers, self.max_length, self.device)
        self.decoder = Decoder(self.embed_size, self.heads, self.forward_expansion, self.trg_vocab_size, dropout, self.max_length, self.num_layers, self.device)

    def make_src_mask(self, src):
        # mask out padding tokens in the source
        # src_mask shape: (N, 1, 1, src_len)
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        # lower-triangular causal mask so each position only attends to earlier positions
        # trg_mask shape: (N, 1, trg_len, trg_len)
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, src_mask)
        dec_out = self.decoder(trg, enc_out, trg_mask, src_mask)
        return dec_out
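A minimal end-to-end smoke test with toy vocabulary sizes and token ids (device switched to 'cpu' for the example):
if __name__ == '__main__':
    device = 'cpu'
    src = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0],
                        [1, 8, 7, 3, 4, 5, 6, 7, 2]])
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0],
                        [1, 5, 6, 2, 4, 7, 6, 2]])
    model = Transformer(src_vocab_size=10, trg_vocab_size=10,
                        src_pad_idx=0, trg_pad_idx=0, device=device)
    out = model(src, trg)
    print(out.shape)    # torch.Size([2, 8, 10]) == (N, trg_len, trg_vocab_size)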