FP16 V still works
This commit is contained in:
parent
ca6d82885c
commit
8a10e5c03c
1 changed file with 6 additions and 4 deletions
|
@@ -53,11 +53,11 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) Q_f;
|
||||
const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
|
||||
const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_K = nb11 / sizeof(type_k);
|
||||
const int stride_KV2 = nb11*qkk / (2*sizeof(type_k));
|
||||
const int stride_V = nb11*qkk / (sizeof(type_v)*qkv);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
@@ -217,7 +217,9 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
|
||||
half2 tmp;
|
||||
dequantize_v(V_h + (k_VKQ_0 + k)*stride_V + (2*i)/qkv, 0, ((2*i)%qkv)/qrv, tmp);
|
||||
KV_tmp[k][i] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue