diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 75e0dddf9..bafe08088 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3554,8 +3554,8 @@ static __device__ __forceinline__ void mul_mat_q(
 #define MMQ_X_Q4_0_RDNA1 64
 #define MMQ_Y_Q4_0_RDNA1 64
 #define NWARPS_Q4_0_RDNA1 8
-#define MMQ_X_Q4_0_AMPERE 64
-#define MMQ_Y_Q4_0_AMPERE 128
+#define MMQ_X_Q4_0_AMPERE 4
+#define MMQ_Y_Q4_0_AMPERE 32
 #define NWARPS_Q4_0_AMPERE 4
 #define MMQ_X_Q4_0_PASCAL 64
 #define MMQ_Y_Q4_0_PASCAL 64
@@ -3615,8 +3615,8 @@ template static __global__ void
 #define MMQ_X_Q4_1_RDNA1 64
 #define MMQ_Y_Q4_1_RDNA1 64
 #define NWARPS_Q4_1_RDNA1 8
-#define MMQ_X_Q4_1_AMPERE 64
-#define MMQ_Y_Q4_1_AMPERE 128
+#define MMQ_X_Q4_1_AMPERE 4
+#define MMQ_Y_Q4_1_AMPERE 32
 #define NWARPS_Q4_1_AMPERE 4
 #define MMQ_X_Q4_1_PASCAL 64
 #define MMQ_Y_Q4_1_PASCAL 64
@@ -3678,8 +3678,8 @@ template static __global__ void
 #define MMQ_X_Q5_0_RDNA1 64
 #define MMQ_Y_Q5_0_RDNA1 64
 #define NWARPS_Q5_0_RDNA1 8
-#define MMQ_X_Q5_0_AMPERE 128
-#define MMQ_Y_Q5_0_AMPERE 64
+#define MMQ_X_Q5_0_AMPERE 4
+#define MMQ_Y_Q5_0_AMPERE 32
 #define NWARPS_Q5_0_AMPERE 4
 #define MMQ_X_Q5_0_PASCAL 64
 #define MMQ_Y_Q5_0_PASCAL 64
@@ -3739,8 +3739,8 @@ template static __global__ void
 #define MMQ_X_Q5_1_RDNA1 64
 #define MMQ_Y_Q5_1_RDNA1 64
 #define NWARPS_Q5_1_RDNA1 8
-#define MMQ_X_Q5_1_AMPERE 128
-#define MMQ_Y_Q5_1_AMPERE 64
+#define MMQ_X_Q5_1_AMPERE 4
+#define MMQ_Y_Q5_1_AMPERE 32
 #define NWARPS_Q5_1_AMPERE 4
 #define MMQ_X_Q5_1_PASCAL 64
 #define MMQ_Y_Q5_1_PASCAL 64
@@ -3800,8 +3800,8 @@ mul_mat_q5_1(
 #define MMQ_X_Q8_0_RDNA1 64
 #define MMQ_Y_Q8_0_RDNA1 64
 #define NWARPS_Q8_0_RDNA1 8
-#define MMQ_X_Q8_0_AMPERE 128
-#define MMQ_Y_Q8_0_AMPERE 64
+#define MMQ_X_Q8_0_AMPERE 4
+#define MMQ_Y_Q8_0_AMPERE 32
 #define NWARPS_Q8_0_AMPERE 4
 #define MMQ_X_Q8_0_PASCAL 64
 #define MMQ_Y_Q8_0_PASCAL 64
@@ -3861,8 +3861,8 @@ template static __global__ void
 #define MMQ_X_Q2_K_RDNA1 128
 #define MMQ_Y_Q2_K_RDNA1 32
 #define NWARPS_Q2_K_RDNA1 8
-#define MMQ_X_Q2_K_AMPERE 64
-#define MMQ_Y_Q2_K_AMPERE 128
+#define MMQ_X_Q2_K_AMPERE 4
+#define MMQ_Y_Q2_K_AMPERE 32
 #define NWARPS_Q2_K_AMPERE 4
 #define MMQ_X_Q2_K_PASCAL 64
 #define MMQ_Y_Q2_K_PASCAL 64
@@ -3922,8 +3922,8 @@ mul_mat_q2_K(
 #define MMQ_X_Q3_K_RDNA1 32
 #define MMQ_Y_Q3_K_RDNA1 128
 #define NWARPS_Q3_K_RDNA1 8
-#define MMQ_X_Q3_K_AMPERE 128
-#define MMQ_Y_Q3_K_AMPERE 128
+#define MMQ_X_Q3_K_AMPERE 4
+#define MMQ_Y_Q3_K_AMPERE 32
 #define NWARPS_Q3_K_AMPERE 4
 #define MMQ_X_Q3_K_PASCAL 64
 #define MMQ_Y_Q3_K_PASCAL 64
@@ -3985,8 +3985,8 @@ template static __global__ void
 #define MMQ_X_Q4_K_RDNA1 32
 #define MMQ_Y_Q4_K_RDNA1 64
 #define NWARPS_Q4_K_RDNA1 8
-#define MMQ_X_Q4_K_AMPERE 64
-#define MMQ_Y_Q4_K_AMPERE 128
+#define MMQ_X_Q4_K_AMPERE 4
+#define MMQ_Y_Q4_K_AMPERE 32
 #define NWARPS_Q4_K_AMPERE 4
 #define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
@@ -4048,8 +4048,8 @@ template static __global__ void
 #define MMQ_X_Q5_K_RDNA1 32
 #define MMQ_Y_Q5_K_RDNA1 64
 #define NWARPS_Q5_K_RDNA1 8
-#define MMQ_X_Q5_K_AMPERE 64
-#define MMQ_Y_Q5_K_AMPERE 128
+#define MMQ_X_Q5_K_AMPERE 4
+#define MMQ_Y_Q5_K_AMPERE 32
 #define NWARPS_Q5_K_AMPERE 4
 #define MMQ_X_Q5_K_PASCAL 64
 #define MMQ_Y_Q5_K_PASCAL 64
@@ -4109,8 +4109,8 @@ mul_mat_q5_K(
 #define MMQ_X_Q6_K_RDNA1 32
 #define MMQ_Y_Q6_K_RDNA1 64
 #define NWARPS_Q6_K_RDNA1 8
-#define MMQ_X_Q6_K_AMPERE 64
-#define MMQ_Y_Q6_K_AMPERE 64
+#define MMQ_X_Q6_K_AMPERE 4
+#define MMQ_Y_Q6_K_AMPERE 32
 #define NWARPS_Q6_K_AMPERE 4
 #define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
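For readers of the patch: each MMQ_X_*/MMQ_Y_* pair controls the tile of the output matrix that one thread block of the matching mul_mat_q kernel computes, and NWARPS_* sets the warps per block. Below is a minimal sketch of how such tile macros typically feed the kernel launch in ggml-cuda.cu; the ceil-div launch arithmetic follows the stock mul_mat_q pattern, but the function and variable names here are illustrative, not verbatim from the file.

// Sketch: how the MMQ tile macros translate into a kernel launch.
// Assumes the usual mul_mat_q launch pattern in ggml-cuda.cu; the
// identifiers (nrows_x, ncols_y, ...) are illustrative assumptions.
#include <cuda_runtime.h>

#define WARP_SIZE 32

static void launch_mul_mat_q4_0_sketch(int nrows_x, int ncols_y, cudaStream_t stream) {
    const int mmq_x  = 4;   // MMQ_X_Q4_0_AMPERE after this patch (was 64)
    const int mmq_y  = 32;  // MMQ_Y_Q4_0_AMPERE after this patch (was 128)
    const int nwarps = 4;   // NWARPS_Q4_0_AMPERE (unchanged)

    // One block per mmq_y x mmq_x output tile, rounding up on both axes.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    // mul_mat_q4_0<<<block_nums, block_dims, 0, stream>>>(...);
    (void)block_nums; (void)block_dims; (void)stream;
}

Shrinking the Ampere tiles from 64/128-class sizes down to 4x32 launches far more, much smaller blocks, trading per-block data reuse for occupancy; together with the dispatch change below, this reads as an experiment in tuning small-batch performance.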
@@ -7252,7 +7252,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

-#if 0
+#if 1
     { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         half * src0_as_f16 = nullptr;
@@ -7309,6 +7309,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
         }
     }
 #else
+    // NOTE: this seems faster for tiny models and small batch-size
     { // convert src0 to fp32, multiply as fp32
         float * src0_as_f32 = nullptr;
@@ -7372,7 +7373,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
-    } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 1) {
+    } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 32) {
+        // F16 and quantized src0 + high-batch src1
         ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
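The #if 0 -> #if 1 flip enables the branch that converts src0 and src1 to fp16, multiplies in fp16 through cuBLAS, and converts the result back to fp32, while the added NOTE records that the fp32 branch still seems faster for tiny models and small batches. Below is a minimal, self-contained sketch of the fp16 GEMM pattern that branch relies on, assuming plain cublasGemmEx with fp16 storage and fp16 compute; allocation, error handling, and the ggml fp32<->fp16 conversion kernels are elided, and the names are illustrative rather than taken from the file.

// Sketch only: fp16 x fp16 -> fp16 GEMM via cublasGemmEx, mirroring the
// "convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32"
// branch. Leading dimensions assume the layouts noted in the comments.
#include <cublas_v2.h>
#include <cuda_fp16.h>

static void gemm_f16_sketch(cublasHandle_t handle,
                            const half * src0_f16, // m x k, row-major
                            const half * src1_f16, // k x n, column-major
                            half * dst_f16,        // m x n, column-major
                            int m, int n, int k) {
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);

    // src0 is row-major, so it is fed to cuBLAS transposed (CUBLAS_OP_T),
    // the same trick ggml-cuda.cu uses for its other cuBLAS mat-muls.
    cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                 m, n, k,
                 &alpha, src0_f16, CUDA_R_16F, k,
                         src1_f16, CUDA_R_16F, k,
                 &beta,  dst_f16,  CUDA_R_16F, m,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    // dst_f16 is subsequently converted back to fp32 (into dst_ddf).
}

The dispatch change at the end means this dequantize-to-cuBLAS path now only takes over once the batch dimension src1->ne[1] exceeds 32; smaller batches stay on the regular kernels, including the MMQ kernels whose Ampere tiles were shrunk above.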