func.func @attention_head(
%x : tensor<seq x d_model>,
%wq : tensor<d_model x d_k>,
%wk : tensor<d_model x d_k>,
%wv : tensor<d_model x d_v>) {
...
}
Возьмём следующие размеры:
- seq = 4 (длина последовательности 4 токена)
- d_model = 8 (скрытое пространство 8)
- d_k = 4 (размерность ключей/запросов)
- d_v = 4 (размерность значений)
func.func @mini_attention(
%x : tensor<4x8>, // 4 tokens, 8 vectors to define embedding
%wq : tensor<8x4>, // 8 vectors to define embedding,
%wk : tensor<8x4>, // 8 vectors to define embedding
%wv : tensor<8x4> // 8 vectors to define embedding
) -> tensor<4x4> {
...
// 2 tokena могут корелировать но ты не можешь узнть на сколько из их эмбэддинга
// чтобы узнать насколько они корелируют нужно дополнительное пространство которое обучено на то чтобы детектировать кореляции
// чтобы получить это пространство мы умножаем текущий эмбэдинг на wQ и wK
// Q of each = embedding x wQ
%Q = linalg.matmul
ins(%x, %wq : tensor<4x8>, tensor<8x4>)
outs(%q_init : tensor<4x4>) -> tensor<4x4>
// K of each = embedding x wK
%K = linalg.matmul
ins(%x, %wk : tensor<4x8>, tensor<8x4>)
outs(%k_init : tensor<4x4>) -> tensor<4x4>
// тут также для того чтобы корелировать-умножать свапаем размерности
//%K_t = tensor.transpose %K, [1, 0] : tensor<4x4> -> tensor<4x4>
// как сильно что влияет
%scores = linalg.matmul
ins(%Q, %K : tensor<4x4>, tensor<4x4>)
outs(%scores_init : tensor<4x4>) -> tensor<4x4>
// нормализация
%attn = linalg.softmax
ins(%scores : tensor<4x4>)
{axis = 1} -> tensor<4x4>
// дальше нужно сдвинть екущий эмбэддинг в векторном пространстве чз сильно корелирующие тоены/эмбэдинги
// для этого нужно получить пространство которое позволит расчитать как повлияет
%V = linalg.matmul
ins(%x, %wv : tensor<4x8>, tensor<8x4>)
outs(%v_init : tensor<4x4>) -> tensor<4x4>
// Output = attn * V
// сдвигаем эмбэдинг в адресном пространстве
%Y = linalg.matmul
ins(%attn, %V : tensor<4x4>, tensor<4x4>)
outs(%y_init : tensor<4x4>) -> tensor<4x4>
}
def solution(arg):
if arg > 10:
return 42
return 0