From 8a10e5c03cd6cf1be463ab57652b9e8f195663d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 11:22:11 +0200 Subject: [PATCH] FP16 v still works --- ggml-cuda/fattn-tile-f16.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 0bcfda7b4..2e3baa793 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -53,11 +53,11 @@ static __global__ void flash_attn_tile_ext_f16( const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const float2 * Q_f2 = (const float2 *) Q_f; const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape + const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_K = nb11 / sizeof(type_k); - const int stride_KV2 = nb11*qkk / (2*sizeof(type_k)); + const int stride_K = nb11 / sizeof(type_k); + const int stride_V = nb11*qkk / (sizeof(type_v)*qkv); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -217,7 +217,9 @@ static __global__ void flash_attn_tile_ext_f16( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i]; + half2 tmp; + dequantize_v(V_h + (k_VKQ_0 + k)*stride_V + (2*i)/qkv, 0, ((2*i)%qkv)/qrv, tmp); + KV_tmp[k][i] = tmp; } }