use dmmv as default

This commit is contained in:
luoyu-intel 2024-07-16 15:46:49 +08:00
parent a8c75c041d
commit 0b8565d979
2 changed files with 21 additions and 20 deletions

View file

@ -3632,7 +3632,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
// check data types and tensor shapes for custom matrix multiplication kernels:
bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32

View file

@ -883,7 +883,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream,char* vx_tmp) {
dpct::queue_ptr stream) {
#if WARP_SIZE==32
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
@ -911,7 +911,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_0(vx_tmp, y, dst, ncols, nrows, item_ct1);
dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1);
});
}
#endif
@ -1083,7 +1083,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
const int64_t src1_ncols, const int64_t src1_padded_row_size,
const int64_t src1_ncols, const int64_t src1_padded_col_size,
const dpct::queue_ptr &stream) {
const int64_t ne00 = src0->ne[0];
@ -1111,50 +1111,51 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
#endif // GGML_SYCL_F16
ggml_sycl_pool_alloc<char> src0_test(ctx.pool());
char *src0_test_ptr = src0_test.alloc(ggml_nbytes(src0));
switch (src0->type) {
for (int i = 0; i < src1_ncols; i++)
{
const dfloat* src1_dfloat_bs = src1_dfloat + i * src1_padded_col_size;
float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
switch (src0->type) {
case GGML_TYPE_Q4_0:
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream,src0_test_ptr);
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_1:
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_0:
dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_1:
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q8_0:
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q3_K:
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_K:
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_K:
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q6_K:
dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_F16:
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
default:
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
GGML_ASSERT(false);
break;
}
}
(void) src1;
(void) dst;
(void) src1_ddq_i;
(void) src1_ncols;
(void) src1_padded_row_size;
}