CUDA: quantized KV support for FA vec

This commit is contained in:
Johannes Gäßler 2024-05-21 19:38:25 +02:00
parent 10b1e45876
commit 672244a88b
11 changed files with 826 additions and 142 deletions

View file

@ -180,8 +180,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
#define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ 8
template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
const int * v, const int * u, const float & d8_0, const float & d8_1) {
template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
const int * v, const int * u, const T & d8_0, const T & d8_1) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
int sumi = 0;
@ -192,7 +192,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
sumi = __dp4a(v[i], u[i], sumi);
}
return d8_0*d8_1 * sumi;
return d8_0*d8_1 * ((T) sumi);
#else
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@ -656,7 +656,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
}
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
}
static __device__ __forceinline__ float vec_dot_q2_K_q8_1(