Spped up Q4_K on CUDA
GTX1660: 29.5 ms/t -> 25.6 ms/t RTX4080: 8.40 ms/t -> 8.25 ms/t
This commit is contained in:
parent
c2a5188b3c
commit
1af6f38b63
1 changed files with 22 additions and 8 deletions
30
ggml-cuda.cu
30
ggml-cuda.cu
|
@ -169,7 +169,7 @@ typedef struct {
|
||||||
} block_q3_K;
|
} block_q3_K;
|
||||||
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
|
||||||
|
|
||||||
#define QR4_K 2
|
#define QR4_K 4
|
||||||
#define QI4_K (QK_K / (4*QR4_K))
|
#define QI4_K (QK_K / (4*QR4_K))
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -1566,7 +1566,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
||||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
|
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
|
||||||
|
|
||||||
const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6
|
// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
|
||||||
|
const int bq8_offset = (QR4_K/2) * (iqs / (QI8_1/2));
|
||||||
|
|
||||||
float sumf_d = 0.0f;
|
float sumf_d = 0.0f;
|
||||||
float sumf_m = 0.0f;
|
float sumf_m = 0.0f;
|
||||||
|
@ -1574,7 +1575,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
||||||
const float d = bq4_K->d;
|
const float d = bq4_K->d;
|
||||||
const float dmin = bq4_K->dmin;
|
const float dmin = bq4_K->dmin;
|
||||||
|
|
||||||
const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
|
// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
|
||||||
|
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
|
||||||
|
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
|
||||||
|
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
|
||||||
|
|
||||||
|
const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
|
||||||
|
const int v1 = q4[0];
|
||||||
|
const int v2 = q4[4];
|
||||||
|
|
||||||
const uint16_t * scales = (const uint16_t *)bq4_K->scales;
|
const uint16_t * scales = (const uint16_t *)bq4_K->scales;
|
||||||
uint16_t aux[2];
|
uint16_t aux[2];
|
||||||
|
@ -1589,16 +1597,22 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
||||||
const uint8_t * sc = (const uint8_t *)aux;
|
const uint8_t * sc = (const uint8_t *)aux;
|
||||||
const uint8_t * m = sc + 2;
|
const uint8_t * m = sc + 2;
|
||||||
|
|
||||||
for (int i = 0; i < QR4_K; ++i) {
|
for (int i = 0; i < QR4_K/2; ++i) {
|
||||||
|
|
||||||
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
||||||
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
|
|
||||||
const float d8i = bq8i->d;
|
const float d8i = bq8i->d;
|
||||||
|
const int * q8 = (const int *)bq8i->qs + (iqs%4);
|
||||||
|
const int ui1 = q8[0];
|
||||||
|
const int ui2 = q8[4];
|
||||||
|
|
||||||
const int vi = (v >> (4*i)) & 0x0F0F0F0F;
|
const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
|
||||||
|
const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
|
||||||
|
|
||||||
sumf_d += d8i * (__dp4a(vi, ui, 0) * sc[i]); // SIMD dot product
|
const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0));
|
||||||
sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]); // multiply constant part of q4_K with sum of q8_1 values
|
const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
|
||||||
|
|
||||||
|
sumf_d += d8i * (dot1 * sc[i]); // SIMD dot product
|
||||||
|
sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
|
||||||
}
|
}
|
||||||
|
|
||||||
return d*sumf_d - dmin*sumf_m;
|
return d*sumf_d - dmin*sumf_m;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue