diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 9523e0934..b510777fb 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1885,7 +1885,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); - bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) + bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index 8968bae40..96a5adef5 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -672,3 +672,12 @@ void ggml_cuda_op_dequantize_mul_mat_vec( GGML_UNUSED(src1_ncols); GGML_UNUSED(src1_padded_row_size); } + +bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { + return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 || + src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 || + src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || + src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || + src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || + src0_type == GGML_TYPE_F16; +} diff --git a/ggml/src/ggml-cuda/dmmv.cuh b/ggml/src/ggml-cuda/dmmv.cuh index 4c5ebd475..e727eb97f 100644 --- a/ggml/src/ggml-cuda/dmmv.cuh +++ b/ggml/src/ggml-cuda/dmmv.cuh @@ -16,3 +16,5 @@ void ggml_cuda_op_dequantize_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); + +bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);