CUDA: quantized KV support for FA vec (#7527)

* CUDA: quantized KV support for FA vec

* try CI fix

* fix commented-out kernel variants

* add q8_0 q4_0 tests

* fix nwarps > batch size

* split fattn compile via extern templates

* fix flake8

* fix metal tests

* fix cmake

* make generate_cu_files.py executable

* add autogenerated .cu files

* fix AMD

* error if type_v != FP16 and not flash_attn

* remove obsolete code
This commit is contained in:
Johannes Gäßler 2024-06-01 08:44:14 +02:00 committed by GitHub
parent a323ec60af
commit 9b596417af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
110 changed files with 2697 additions and 1200 deletions

View file

@ -386,7 +386,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
}
return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
return vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
(&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
@ -547,7 +547,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds;
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
}