注意力机制

通过tensorflow手撕注意力机制
深度学习
Tensorflow
Author

Hahabula

Published

2025-05-20

Modified

2025-05-20

1 注意力机制

1.1 核心思想

注意力机制模拟人类视觉或认知系统在处理信息时会“关注”关键部分的能力。在模型中,它可以让网络自动学习“应该关注哪些部分的输入”。

1.2 常见形式

最常用的形式是缩放点积注意力:

\[ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V \]

其中,\(Q\)\(K\)\(V\) 分别是查询、键和值向量,\(d_k\)是键的维度。

讨论
  1. 查询、键、值向量的获取

    通过一个全连接层将某个序列映射成相应的向量,例如:

\[ Q = \text{FC}(x_1)=Wx_1 \]

  1. 缩放点积注意力

    为啥用 \(\frac{QK'}{\sqrt{d_k}}\) 是因为点积的结果可能会过大,导致softmax的计算结果不准确。因此,缩放点积注意力的目的是将点积结果缩放到一个合适的范围。

  2. 注意力机制的分类

  • 自注意力:查询、键和值向量来自同一个序列。
  • 交叉注意力:查询来自A,键、值来自序列B。
  • 统计建模中的注意力:查询来自B,键来自A,值来自concat[A,B],tf.concat([A,B], axis = -1)表示将A和B按最后一维拼接。
  • 多头注意力机制:对键,值,查询的维度分为 \(h\) 份,然后分别计算注意力分数,最后再拼接并接上一个全连接层。

多头注意力机制示意图

多头注意力机制示意图 1

2 Tensorflow实现

import tensorflow as tf

class CustomAttention(tf.keras.layers.Layer):
    def __init__(self, d_k, d_v):
        super().__init__()
        self.d_k = d_k  # 键/查询维度
        self.d_v = d_v  # 值维度

    def build(self, input_shape):
        self.W_q = self.add_weight(
            name = "W_q",
            shape = (input_shape[-1], self.d_k),
            initializer = "glorot_uniform",
            trainable = True
        )  # 查询矩阵

        self.W_k = self.add_weight(
            name = "W_k",
            shape = (input_shape[-1], self.d_k),
            initializer = "glorot_uniform",
            trainable = True
        )  # 键矩阵

        self.W_v = self.add_weight(
            name = "W_v",
            shape = (input_shape[-1], self.d_v),
            initializer = "glorot_uniform",
            trainable = True
1        )

    def call(self, inputs):
2        Q = tf.matmul(inputs, self.W_q)
        K = tf.matmul(inputs, self.W_k)  # [batch_size, seq_len, d_k]
        V = tf.matmul(inputs, self.W_v)  # [batch_size, seq_len, d_v]

        # 计算注意力分数
        scores = tf.matmul(Q, K, transpose_b = True)  # [batch_size, seq_len, seq_len]
3        scaled_scores = scores / tf.math.sqrt(tf.cast(self.d_k, tf.float32))

        # 应用softmax
4        attention_weights = tf.nn.softmax(scaled_scores, axis = -1)  # [batch_size, seq_len, seq_len]

        # 上下文向量
        context = tf.matmul(attention_weights, V)  # [bath_size, seq_len, d_v]
        return context

# 示例
input_layer = tf.keras.Input(shape=(10, 64))  # seq_len = 10, dim = 64
5attention_output = CustomAttention(d_k=128, d_v=256)(input_layer)
model = tf.keras.models.Model(inputs=input_layer, outputs=attention_output)
print(model.summary())
1
自定义权重的方式self.W = tf.add_weight(name = "W", shape = (input_shape[-1], self.d_v), initializer = "glorot_uniform", trainable = True)(名字->shape->初始化器->可训练性)
2
矩阵乘法的计算tf.matmul(inputs, self.W_q, transpose_b = True): 其中transpose_b = True表示将第二个矩阵转置后再进行乘法运算。
3
平方根tf.math.sqrt(): 用于计算张量的平方根;数据转换tf.cast(): 用于将张量的数据类型转换为另一种类型;经过计算后的维度为[batch_size, seq_len, seq_len],由计算的方式得最后一维的seq_len表示每个词对其他词的注意力权重。因此需要在后续的softmax中的计算中对最后一个维度进行计算。
4
tf.nn.softmax(): 用于计算张量的softmax值。axis = -1表示在最后一个维度上进行softmax计算。
5
CustomAttention()(input_layer): 调用自定义的注意力层。形成层级关系
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 10, 64)]          0         
                                                                 
 custom_attention (CustomAtt  (None, 10, 256)          32768     
 ention)                                                         
                                                                 
=================================================================
Total params: 32,768
Trainable params: 32,768
Non-trainable params: 0
_________________________________________________________________
None
import tensorflow as tf

class CustomAttention(tf.keras.layers.Layer):
    def __init__(self, d_k, d_v):
        super().__init__()
        self.d_k = d_k  # 键/查询维度
        self.d_v = d_v  # 值维度
1        self.W_k = tf.keras.layers.Dense(d_k, use_bias=False)
        self.W_q = tf.keras.layers.Dense(d_k, use_bias=False)
        self.W_v = tf.keras.layers.Dense(d_v, use_bias=False)

    def call(self, inputs):
        Q = self.W_q(inputs)  # [batch_size, seq_len, d_k]
        K = self.W_k(inputs)  # [batch_size, seq_len, d_k]
        V = self.W_v(inputs)  # [batch_size, seq_len, d_v]

        # 计算注意力分数
        scoress = tf.matmul(Q, K, transpose_b=True)  # [b, s, s]
        scaled_scores = scoress / tf.math.sqrt(tf.cast(self.d_k, tf.float32))  # [b, s, s]

        # 应用softmax
        attention_weights = tf.nn.softmax(scaled_scores, axis=-1)  # [b, s, s]

        # 上下文向量
        context = tf.matmul(attention_weights, V)  # [b, s, d_v]
        return context

# 示例
input_layer = tf.keras.Input(shape=(10, 64))  # seq_len = 10, dim = 64
attention_output = CustomAttention(d_k=128, d_v=256)(input_layer)  
model = tf.keras.models.Model(inputs=input_layer, outputs=attention_output)
print(model.summary())
1
记得在自定义层中定义全连接层,并将use_bias=False设置为不使用偏置。
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 10, 64)]          0         
                                                                 
 custom_attention_1 (CustomA  (None, 10, 256)          32768     
 ttention)                                                       
                                                                 
=================================================================
Total params: 32,768
Trainable params: 32,768
Non-trainable params: 0
_________________________________________________________________
None
import tensorflow as tf

# 定义输入(假设序列长度为10,特征维度64)
query = tf.random.normal((32, 10, 64))  # batch_size=32
key = tf.random.normal((32, 10, 64))
value = tf.random.normal((32, 10, 64))

# 使用内置 Attention 层
attention_layer = tf.keras.layers.Attention()
output = attention_layer([query, value], mask=None)  # 自动计算 key=query
print(output.shape)  # (32, 10, 64)
(32, 10, 64)
Back to top

Footnotes

  1. https://epjournal.csee.org.cn/ncdqh/supplement/a566e005-3b84-42b3-921b-11949269b5d5↩︎