Address PR comments
This commit is contained in:
parent
16c24c6444
commit
6baa4ead58
1 changed files with 13 additions and 13 deletions
26
ggml-cuda.cu
26
ggml-cuda.cu
|
@ -169,7 +169,7 @@ typedef struct {
|
||||||
} block_q3_K;
|
} block_q3_K;
|
||||||
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
||||||
|
|
||||||
#define QR4_K 4
|
#define QR4_K 2
|
||||||
#define QI4_K (QK_K / (4*QR4_K))
|
#define QI4_K (QK_K / (4*QR4_K))
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -188,7 +188,7 @@ typedef struct {
|
||||||
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
|
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define QR5_K 4
|
#define QR5_K 2
|
||||||
#define QI5_K (QK_K / (4*QR5_K))
|
#define QI5_K (QK_K / (4*QR5_K))
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -1567,7 +1567,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
||||||
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
|
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
|
||||||
|
|
||||||
// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
|
// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
|
||||||
const int bq8_offset = (QR4_K/2) * (iqs / (QI8_1/2));
|
const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
|
||||||
|
|
||||||
float sumf_d = 0.0f;
|
float sumf_d = 0.0f;
|
||||||
float sumf_m = 0.0f;
|
float sumf_m = 0.0f;
|
||||||
|
@ -1597,7 +1597,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
||||||
const uint8_t * sc = (const uint8_t *)aux;
|
const uint8_t * sc = (const uint8_t *)aux;
|
||||||
const uint8_t * m = sc + 2;
|
const uint8_t * m = sc + 2;
|
||||||
|
|
||||||
for (int i = 0; i < QR4_K/2; ++i) {
|
for (int i = 0; i < QR4_K; ++i) {
|
||||||
|
|
||||||
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
||||||
const float d8i = bq8i->d;
|
const float d8i = bq8i->d;
|
||||||
|
@ -1608,10 +1608,10 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
||||||
const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
|
const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
|
||||||
const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
|
const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
|
||||||
|
|
||||||
const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0));
|
const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
|
||||||
const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
|
const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
|
||||||
|
|
||||||
sumf_d += d8i * (dot1 * sc[i]); // SIMD dot product
|
sumf_d += d8i * (dot1 * sc[i]);
|
||||||
sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
|
sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1627,7 +1627,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
||||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
const block_q5_K * bq5_K = (const block_q5_K *) vbq;
|
const block_q5_K * bq5_K = (const block_q5_K *) vbq;
|
||||||
|
|
||||||
const int bq8_offset = (QR5_K/2) * (iqs / (QI8_1/2));
|
const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
|
||||||
const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
|
const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
|
||||||
const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
|
const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
|
||||||
|
|
||||||
|
@ -1656,7 +1656,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
||||||
const uint8_t * sc = (const uint8_t *)aux;
|
const uint8_t * sc = (const uint8_t *)aux;
|
||||||
const uint8_t * m = sc + 2;
|
const uint8_t * m = sc + 2;
|
||||||
|
|
||||||
for (int i = 0; i < QR5_K/2; ++i) {
|
for (int i = 0; i < QR5_K; ++i) {
|
||||||
|
|
||||||
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
||||||
const float d8i = bq8i->d;
|
const float d8i = bq8i->d;
|
||||||
|
@ -1673,11 +1673,11 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
||||||
const int vi1 = vil1 | vih1;
|
const int vi1 = vil1 | vih1;
|
||||||
const int vi2 = vil2 | vih2;
|
const int vi2 = vil2 | vih2;
|
||||||
|
|
||||||
const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0));
|
const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
|
||||||
const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
|
const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
|
||||||
|
|
||||||
sumf_d += d8i * (dot1 * sc[i]); // SIMD dot product
|
sumf_d += d8i * (dot1 * sc[i]);
|
||||||
sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
|
sumf_m += d8i * (dot2 * m[i]);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2349,7 +2349,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
|
mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2358,7 +2358,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
|
mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue