import tensorflow as tf
class CustomAttention(tf.keras.layers.Layer):
    def __init__(self, d_k, d_v):
        super().__init__()
        self.d_k = d_k  # key/query dimension
        self.d_v = d_v  # value dimension
    def build(self, input_shape):
        self.W_q = self.add_weight(
            name="W_q",
            shape=(input_shape[-1], self.d_k),
            initializer="glorot_uniform",
            trainable=True
        )  # query projection matrix
        self.W_k = self.add_weight(
            name="W_k",
            shape=(input_shape[-1], self.d_k),
            initializer="glorot_uniform",
            trainable=True
        )  # key projection matrix
        self.W_v = self.add_weight(
            name="W_v",
            shape=(input_shape[-1], self.d_v),
            initializer="glorot_uniform",
            trainable=True
        )  # value projection matrix
    def call(self, inputs):
        Q = tf.matmul(inputs, self.W_q)  # [batch_size, seq_len, d_k]
        K = tf.matmul(inputs, self.W_k)  # [batch_size, seq_len, d_k]
        V = tf.matmul(inputs, self.W_v)  # [batch_size, seq_len, d_v]
        # Compute attention scores
        scores = tf.matmul(Q, K, transpose_b=True)  # [batch_size, seq_len, seq_len]
        scaled_scores = scores / tf.math.sqrt(tf.cast(self.d_k, tf.float32))
        # Apply softmax over the key axis
        attention_weights = tf.nn.softmax(scaled_scores, axis=-1)  # [batch_size, seq_len, seq_len]
        # Context vectors: weighted sum of the values
        context = tf.matmul(attention_weights, V)  # [batch_size, seq_len, d_v]
        return context
# Example usage
input_layer = tf.keras.Input(shape=(10, 64))  # seq_len = 10, dim = 64
attention_output = CustomAttention(d_k=128, d_v=256)(input_layer)
model = tf.keras.models.Model(inputs=input_layer, outputs=attention_output)
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 10, 64)]          0
_________________________________________________________________
custom_attention (CustomAtt  (None, 10, 256)           32768
ention)
=================================================================
Total params: 32,768
Trainable params: 32,768
Non-trainable params: 0
_________________________________________________________________
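The parameter count checks out: W_q and W_k each map 64 to 128 (64 × 128 = 8,192 apiece) and W_v maps 64 to 256 (64 × 256 = 16,384), giving 8,192 + 8,192 + 16,384 = 32,768. A minimal sanity-check sketch, assuming the `model` built in the example above and using random input just for illustration:

import numpy as np

# Hypothetical check: run the model on random data and confirm shapes and parameter count.
x = np.random.rand(2, 10, 64).astype("float32")  # batch of 2, seq_len 10, dim 64
y = model(x)
print(y.shape)                # (2, 10, 256) -> seq_len preserved, last dim is d_v
print(model.count_params())   # 32768 = 64*128 + 64*128 + 64*256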