iq2_xxs: slightly faster CUDA dot product
TG-128 is now at 155.1 t/s.
This commit is contained in:
parent
8240521901
commit
c19d0d09ba
1 changed files with 21 additions and 1 deletions
22
ggml-cuda.cu
22
ggml-cuda.cu
|
@ -477,7 +477,7 @@ typedef struct {
|
|||
} block_q6_K;
|
||||
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
|
||||
|
||||
#define QR2_XXS 4
|
||||
#define QR2_XXS 8
|
||||
#define QI2_XXS (QK_K / (4*QR2_XXS))
|
||||
typedef struct {
|
||||
half d;
|
||||
|
@ -3960,6 +3960,25 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|||
#if QK_K == 256
|
||||
const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
|
||||
|
||||
#if QR2_XXS == 8
|
||||
const int ib32 = iqs;
|
||||
const uint16_t * q2 = bq2->qs + 4*ib32;
|
||||
const uint8_t * aux8 = (const uint8_t *)q2;
|
||||
const int8_t * q8 = bq8_1[ib32].qs;
|
||||
uint32_t aux32 = q2[2] | (q2[3] << 16);
|
||||
int sumi = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
|
||||
const uint8_t signs = ksigns_iq2xs[aux32 & 127];
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
||||
}
|
||||
q8 += 8;
|
||||
aux32 >>= 7;
|
||||
}
|
||||
const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
|
||||
return d * sumi;
|
||||
#else
|
||||
// iqs is 0...15
|
||||
const int ib32 = iqs/2;
|
||||
const int il = iqs%2;
|
||||
|
@ -3978,6 +3997,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|||
sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
|
||||
}
|
||||
return d * (sumi1 + sumi2);
|
||||
#endif
|
||||
#else
|
||||
assert(false);
|
||||
return 0.f;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue