cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X (#8800)

* cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X

* update asserts

* only use dmmv for supported types

* add test
This commit is contained in:
slaren 2024-08-01 15:26:22 +02:00 committed by GitHub
parent c8a0090922
commit 7a11eb3a26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 22 additions and 11 deletions

View file

@ -16,3 +16,5 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);