f16 still works
parent 1ea2a0036e
commit 08d8a6b528
5 changed files with 24 additions and 14 deletions
@@ -101,3 +101,11 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
     v.y *= d;
 #endif // GGML_CUDA_F16
 }
+
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const half * x = (const half *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
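convert_f16 has the same signature as the quantized dequantizers (it fills a dfloat2 from an opaque pointer, a block index ib, and an in-block offset iqs), so fp16 data can be plugged in anywhere a dequantize_kernel_t template parameter is expected. Below is a minimal host-side C++ analogue of that pattern, using plain float instead of half/dfloat2; every name in it is illustrative and not from the repository.

#include <cstdint>
#include <cstdio>

// Simplified stand-ins for dfloat2 / dequantize_kernel_t (illustrative only).
struct val2 { float x, y; };
typedef void (*convert_kernel_t)(const void * vx, int64_t ib, int iqs, val2 & v);

// Analogue of convert_f16: the data is already in its final type, so the
// "dequantizer" just loads two consecutive values.
void convert_plain(const void * vx, const int64_t ib, const int iqs, val2 & v) {
    const float * x = (const float *) vx;
    v.x = x[ib + iqs + 0];
    v.y = x[ib + iqs + 1];
}

// A routine templated on the converter can consume plain and quantized storage
// through one interface, which is what the kernel changes below rely on.
template <convert_kernel_t convert>
float sum_pairs(const void * vx, const int n_pairs) {
    float acc = 0.0f;
    for (int i = 0; i < n_pairs; ++i) {
        val2 v;
        convert(vx, 2*i, 0, v);
        acc += v.x + v.y;
    }
    return acc;
}

int main() {
    const float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    std::printf("%f\n", sum_pairs<convert_plain>(data, 2)); // prints 10.000000
    return 0;
}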

@@ -565,14 +565,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
     }
 }
 
-static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
-    const half * x = (const half *) vx;
-
-    // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
-}
-
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block

@@ -94,7 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 

@@ -1,10 +1,11 @@
 #include "common.cuh"
+#include "dequantize.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile-f16.cuh"
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -48,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half * maskh = (const half *) mask + ne11*ic0;
 
+    const int stride_K = nb11 / sizeof(type_k);
     const int stride_KV2 = nb11 / sizeof(half2);
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);

@@ -108,7 +110,9 @@ static __global__ void flash_attn_tile_ext_f16(
            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                const int k_KQ = k_KQ_0 + threadIdx.x;
 
-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                half2 tmp;
+                dequantize_k(K_h, (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, (2*k_KQ)%qkk, tmp);
+                KV_tmp[i_KQ][k_KQ] = tmp;
            }
        }
 
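The replacement indexing splits the flat column 2*k_KQ into an element (or block) index, (2*k_KQ)/qkk, and an offset inside that element, (2*k_KQ)%qkk, before handing both to the dequantize callback. A small self-contained sketch of that split, with names (map_column, values_per_elem, col) chosen here purely for illustration:

#include <cstdio>

// Sketch of the (2*k_KQ)/qkk, (2*k_KQ)%qkk mapping used above: a flat column
// index is turned into (element index, offset inside the element) for a
// storage type that packs `values_per_elem` half values per element.
void map_column(const int col, const int values_per_elem, const int stride_K, const int row) {
    const int ib  = row*stride_K + col/values_per_elem; // passed as the callback's ib
    const int iqs = col%values_per_elem;                // passed as the callback's iqs
    std::printf("col %2d -> ib %2d, iqs %2d\n", col, ib, iqs);
}

int main() {
    // values_per_elem == 2 corresponds to the half2/convert_f16 instantiation below;
    // a larger value (e.g. 32) would correspond to a hypothetical quantized block type.
    map_column(6,  2, 64, 1);   // col 6 -> ib 67, iqs 0
    map_column(6, 32,  4, 1);   // col 6 -> ib  4, iqs 6
    return 0;
}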

@@ -270,13 +274,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
        case 64: {
            constexpr int D = 64;
            constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, half2, 2, convert_f16>;
            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        } break;
        case 128: {
            constexpr int D = 128;
            constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, half2, 2, convert_f16>;
            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        } break;
        default: {

@@ -457,11 +457,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
+    if (ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
+        ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        return;
+    }
+
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT) {