
Implementing Multi-Head Cross-Attention

How to Do It

In an Encoder-Decoder model, the attention mechanism ties the encoder's outputs to the decoder's hidden states.


What we need to compute is the decoder output's attention over the encoder output (in the code, "encoder output" means the RNN's output sequence, not the final hidden state of the last layer).

For encoder-decoder cross-attention, we treat the context (the encoder output) as K and V, and the decoder output as Q, then apply the usual multi-head attention procedure.
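
In formula form, this is the standard scaled dot-product attention from Attention Is All You Need, with Q projected from the decoder output and K, V projected from the encoder output (here X_dec and X_enc denote the decoder/encoder outputs, W^Q, W^K, W^V the learned projection matrices, and d_k the per-head dimension):

$$
\mathrm{CrossAttention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\qquad Q = X_{\mathrm{dec}} W^Q,\quad K = X_{\mathrm{enc}} W^K,\quad V = X_{\mathrm{enc}} W^V
$$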

Code

Encoder-Decoder Attention

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()

    def forward(self, decoder_context, encoder_context):
        # Dot product of decoder_context and encoder_context gives the attention scores
        scores = torch.matmul(decoder_context, encoder_context.transpose(-2, -1))
        # Normalize the scores into attention weights
        attn_weights = nn.functional.softmax(scores, dim=-1)
        # Weight encoder_context by the attention weights to get the context vector
        context = torch.matmul(attn_weights, encoder_context)
        return context, attn_weights
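
A minimal usage sketch, assuming made-up shapes (batch size 2, decoder length 5, encoder length 7, hidden size 128):

# Hypothetical example inputs, just to check shapes
decoder_context = torch.randn(2, 5, 128)  # (batch_size, tgt_seq_len, hidden)
encoder_context = torch.randn(2, 7, 128)  # (batch_size, src_seq_len, hidden)

attn = Attention()
context, attn_weights = attn(decoder_context, encoder_context)
print(context.shape)       # torch.Size([2, 5, 128])
print(attn_weights.shape)  # torch.Size([2, 5, 7])

Note that this simple version does not scale the scores by sqrt(d); the multi-head version below adds that step.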

Multi-Head Cross-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # Linear projections for Q, K and V
        self.linear_q = torch.nn.Linear(128, 128)
        self.linear_k = torch.nn.Linear(128, 128)
        self.linear_v = torch.nn.Linear(128, 128)
        # Linear layer applied to the concatenated heads
        self.linear_out = torch.nn.Linear(128, 128)
    
    def split_heads(self, tensor, num_heads):
        # (batch_size, seq_len, feature_dim) -> (batch_size, num_heads, seq_len, head_dim)
        batch_size, seq_len, feature_dim = tensor.size()
        head_dim = feature_dim // num_heads
        return tensor.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

    def combine_heads(self, tensor, num_heads):
        # (batch_size, num_heads, seq_len, head_dim) -> (batch_size, seq_len, feature_dim)
        batch_size, num_heads, seq_len, head_dim = tensor.size()
        feature_dim = num_heads * head_dim
        return tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)

    def forward(self, decoder_context, encoder_context):
        # Multi-head cross-attention of decoder_context over encoder_context

        # Number of heads; the hidden size here is 128, so we use 8 heads
        num_heads = 8
        # head_dim = 128 / 8 = 16

        # K and V must come from the same source, namely encoder_context
        K = self.linear_k(encoder_context) # (batch_size, src_seq_len, 128)
        V = self.linear_v(encoder_context) # (batch_size, src_seq_len, 128)
        # Q comes from decoder_context
        Q = self.linear_q(decoder_context) # (batch_size, tgt_seq_len, 128)

        Q = self.split_heads(Q, num_heads) # (batch_size, num_heads=8, tgt_seq_len, head_dim=16)
        K = self.split_heads(K, num_heads) # (batch_size, num_heads=8, src_seq_len, head_dim=16)
        V = self.split_heads(V, num_heads) # (batch_size, num_heads=8, src_seq_len, head_dim=16)

        # Dot product of Q and K gives the similarity scores, i.e. the raw attention weights
        raw_weights = torch.matmul(Q, K.transpose(-2, -1)) # (batch_size, num_heads, tgt_seq_len, src_seq_len)

        # Scale the raw attention weights
        scale_factor = K.size(-1) ** 0.5
        scaled_weights = raw_weights / scale_factor # (batch_size, num_heads, tgt_seq_len, src_seq_len)

        # Softmax-normalize the scaled weights to get the attention weights
        attn_weights = F.softmax(scaled_weights, dim=-1)  # (batch_size, num_heads, tgt_seq_len, src_seq_len)

        # Apply the attention weights to V and take the weighted sum
        attn_outputs = torch.matmul(attn_weights, V) # (batch_size, num_heads=8, tgt_seq_len, head_dim=16)

        # Concatenate the heads back together
        attn_outputs = self.combine_heads(attn_outputs, num_heads)  # (batch_size, tgt_seq_len, feature_dim=128)

        # Final linear transformation
        context = self.linear_out(attn_outputs) # (batch_size, tgt_seq_len, output_dim=128)

        return context, attn_weights
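
And a matching usage sketch for the multi-head version, with the same made-up shapes as above:

# Hypothetical example inputs, just to check shapes
decoder_context = torch.randn(2, 5, 128)  # (batch_size, tgt_seq_len, 128)
encoder_context = torch.randn(2, 7, 128)  # (batch_size, src_seq_len, 128)

mha = MultiHeadAttention()
context, attn_weights = mha(decoder_context, encoder_context)
print(context.shape)       # torch.Size([2, 5, 128])
print(attn_weights.shape)  # torch.Size([2, 8, 5, 7]), one weight matrix per head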

References

Course code from 深蓝学院
Attention Is All You Need (Vaswani et al., 2017)
