FP16 v still works

This commit is contained in:
Johannes Gäßler 2024-05-20 11:22:11 +02:00
parent ca6d82885c
commit 8a10e5c03c

View file

@@ -53,11 +53,11 @@ static __global__ void flash_attn_tile_ext_f16(
 const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
 const float2 * Q_f2 = (const float2 *) Q_f;
 const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
-const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
+const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape
 const half * maskh = (const half *) mask + ne11*ic0;
 const int stride_K = nb11 / sizeof(type_k);
-const int stride_KV2 = nb11*qkk / (2*sizeof(type_k));
+const int stride_V = nb11*qkk / (sizeof(type_v)*qkv);
 const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
 const half slopeh = __float2half(slopef);
@@ -217,7 +217,9 @@ static __global__ void flash_attn_tile_ext_f16(
 for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
 const int i = i0 + threadIdx.x;
-KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+half2 tmp;
+dequantize_v(V_h + (k_VKQ_0 + k)*stride_V + (2*i)/qkv, 0, ((2*i)%qkv)/qrv, tmp);
+KV_tmp[k][i] = tmp;
 }
 }