From 6baa4ead5890fbb36c95238c6474795692dbe268 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 23 Jul 2023 20:13:23 +0300 Subject: [PATCH] Address PR comments --- ggml-cuda.cu | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 38c3b760b..73be3a3cd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -169,7 +169,7 @@ typedef struct { } block_q3_K; //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); -#define QR4_K 4 +#define QR4_K 2 #define QI4_K (QK_K / (4*QR4_K)) #ifdef GGML_QKK_64 typedef struct { @@ -188,7 +188,7 @@ typedef struct { static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); #endif -#define QR5_K 4 +#define QR5_K 2 #define QI5_K (QK_K / (4*QR5_K)) #ifdef GGML_QKK_64 typedef struct { @@ -1567,7 +1567,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const block_q4_K * bq4_K = (const block_q4_K *) vbq; // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6 - const int bq8_offset = (QR4_K/2) * (iqs / (QI8_1/2)); + const int bq8_offset = QR4_K * (iqs / (QI8_1/2)); float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -1597,7 +1597,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const uint8_t * sc = (const uint8_t *)aux; const uint8_t * m = sc + 2; - for (int i = 0; i < QR4_K/2; ++i) { + for (int i = 0; i < QR4_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; const float d8i = bq8i->d; @@ -1608,10 +1608,10 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F; const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F; - const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); + const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - sumf_d += d8i * (dot1 * sc[i]); // SIMD dot product + sumf_d += d8i * (dot1 * sc[i]); sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values } @@ -1627,7 +1627,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q5_K * bq5_K = (const block_q5_K *) vbq; - const int bq8_offset = (QR5_K/2) * (iqs / (QI8_1/2)); + const int bq8_offset = QR5_K * (iqs / (QI8_1/2)); const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4)); const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4)); @@ -1656,7 +1656,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const uint8_t * sc = (const uint8_t *)aux; const uint8_t * m = sc + 2; - for (int i = 0; i < QR5_K/2; ++i) { + for (int i = 0; i < QR5_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; const float d8i = bq8i->d; @@ -1673,11 +1673,11 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const int vi1 = vil1 | vih1; const int vi2 = vil2 | vih2; - const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); + const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - sumf_d += d8i * (dot1 * sc[i]); // SIMD dot product - sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + sumf_d += d8i * (dot1 * sc[i]); + sumf_m += d8i * (dot2 * m[i]); } @@ -2349,7 +2349,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2358,7 +2358,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); }