← Brainpick Wiki

LLM MLIR attention

LLM MLIR attention.md

func.func @attention_head(
	%x  : tensor<seq x d_model>,
	%wq : tensor<d_model x d_k>,
	%wk : tensor<d_model x d_k>,
	%wv : tensor<d_model x d_v>) {
	...	
}

Возьмём следующие размеры:

  • seq = 4 (длина последовательности 4 токена)
  • d_model = 8 (скрытое пространство 8)
  • d_k = 4 (размерность ключей/запросов)
  • d_v = 4 (размерность значений)

func.func @mini_attention(
    %x  : tensor<4x8>, // 4 tokens, 8 vectors to define embedding
    %wq : tensor<8x4>, // 8 vectors to define embedding, 
    %wk : tensor<8x4>, // 8 vectors to define embedding
    %wv : tensor<8x4>  // 8 vectors to define embedding
) -> tensor<4x4> {
	...
	
	// 2 tokena могут корелировать но ты не можешь узнть на сколько из их эмбэддинга
	// чтобы узнать насколько они корелируют нужно дополнительное пространство которое обучено на то чтобы детектировать кореляции
	// чтобы получить это пространство мы умножаем текущий эмбэдинг на wQ и wK
	
	// Q of each = embedding x wQ
	%Q = linalg.matmul
		ins(%x, %wq : tensor<4x8>, tensor<8x4>)
		outs(%q_init : tensor<4x4>) -> tensor<4x4>
	
	// K of each = embedding x wK
	%K = linalg.matmul
		ins(%x, %wk : tensor<4x8>, tensor<8x4>)
		outs(%k_init : tensor<4x4>) -> tensor<4x4>
	
	// тут также для того чтобы корелировать-умножать свапаем размерности
	//%K_t = tensor.transpose %K, [1, 0] : tensor<4x4> -> tensor<4x4>
	
	// как сильно что влияет
	%scores = linalg.matmul
		ins(%Q, %K : tensor<4x4>, tensor<4x4>)
		outs(%scores_init : tensor<4x4>) -> tensor<4x4>
	
	// нормализация	
	%attn = linalg.softmax
	    ins(%scores : tensor<4x4>)
	    {axis = 1} -> tensor<4x4>
	    
	// дальше нужно сдвинть екущий эмбэддинг в векторном пространстве чз сильно корелирующие тоены/эмбэдинги
	// для этого нужно получить пространство которое позволит расчитать как повлияет
	%V = linalg.matmul
	    ins(%x, %wv : tensor<4x8>, tensor<8x4>)
		outs(%v_init : tensor<4x4>) -> tensor<4x4>
	    
	// Output = attn * V
	// сдвигаем эмбэдинг в адресном пространстве
	%Y = linalg.matmul
		ins(%attn, %V : tensor<4x4>, tensor<4x4>)
		outs(%y_init : tensor<4x4>) -> tensor<4x4>

}

def solution(arg):
	if arg > 10:
		return 42
	return 0