iq2_xxs: quantized CUDA dot product (MMVQ)
We get TG-128 = 153.1 t/s
This commit is contained in:
parent
06e6908a6b
commit
8240521901
1 changed files with 41 additions and 0 deletions
41
ggml-cuda.cu
41
ggml-cuda.cu
|
@ -3955,6 +3955,35 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
|
|||
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
|
||||
}
|
||||
|
||||
// Partial dot product between an iq2_xxs block and the matching q8_1 data.
//
//   vbq   - pointer to the block_iq2_xxs to read (passed type-erased).
//   bq8_1 - q8_1 blocks paired with this iq2_xxs block; element [iqs/2] is used.
//   iqs   - slice selector, 0...15. iqs/2 picks a group of four uint16_t in
//           bq2->qs (and the q8_1 block), iqs%2 picks which 16-value half of
//           that group this call handles.
//
// Returns the scaled sum over the 16 products handled by this call.
//
// Layout of one group of four uint16_t, as read below:
//   - bytes 0..3 (q2[0], q2[1]) : four one-byte indices into kgrid_iq2xxs
//   - q2[2] | q2[3] << 16       : 32-bit word holding four 7-bit indices into
//                                 ksigns_iq2xs (bits 0..27) and a 4-bit scale
//                                 (bits 28..31)
// Only built for QK_K == 256; any other configuration asserts.
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;

    // iqs is 0...15
    const int ib32 = iqs/2;
    const int il   = iqs%2;

    const uint16_t * q2   = bq2->qs + 4*ib32;
    const uint8_t  * aux8 = (const uint8_t *)q2;

    // Each selected grid entry is read as 8 consecutive bytes (one magnitude per value).
    const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
    const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);

    // Widen to uint32_t BEFORE shifting: a bare `q2[3] << 16` promotes the
    // uint16_t to signed int, and shifting a set top bit (scale >= 8) into the
    // sign position is a signed overflow.
    const uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);

    // Combined scale: block scale * (0.5 + 4-bit sub-scale) * q8_1 scale * 0.25.
    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;

    // Two 7-bit sign-table indices for this half: bits [14*il, 14*il+7) and [14*il+7, 14*il+14).
    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];

    const int8_t * q8 = bq8_1[ib32].qs + 16*il;

    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < 8; ++j) {
        // A set bit in the sign byte (tested through kmask_iq2xs[j]) negates value j.
        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
    }
    return d * (sumi1 + sumi2);
#else
    assert(false);
    return 0.f;
#endif
}
|
||||
|
||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||
static __device__ __forceinline__ void mul_mat_q(
|
||||
|
@ -6055,6 +6084,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
|
|||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
// Host-side launcher for the iq2_xxs x q8_1 matrix-vector kernel.
//
//   vx, vy - operands forwarded unchanged to mul_mat_vec_q
//   dst    - output buffer forwarded unchanged to mul_mat_vec_q
//   ncols  - column count; must be a multiple of QK_K (asserted)
//   nrows  - row count; determines the number of thread blocks
//   stream - CUDA stream the kernel is enqueued on
//
// Launch shape: ceil(nrows / GGML_CUDA_MMV_Y) blocks along x, each block
// WARP_SIZE x GGML_CUDA_MMV_Y threads, no dynamic shared memory.
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);

    // Round up so a trailing partial group of rows still gets a block.
    const int rows_per_block = GGML_CUDA_MMV_Y;
    const int n_blocks       = (nrows + rows_per_block - 1) / rows_per_block;

    const dim3 grid_shape(n_blocks, 1, 1);
    const dim3 thread_shape(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        <<<grid_shape, thread_shape, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
|
||||
|
||||
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
@ -7619,6 +7657,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue