use dmmv as default
This commit is contained in:
parent
a8c75c041d
commit
0b8565d979
2 changed files with 21 additions and 20 deletions
|
@ -3632,7 +3632,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
// check data types and tensor shapes for custom matrix multiplication kernels:
|
||||
bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
|
|
|
@ -883,7 +883,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
|
|||
static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
dpct::queue_ptr stream,char* vx_tmp) {
|
||||
dpct::queue_ptr stream) {
|
||||
#if WARP_SIZE==32
|
||||
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
|
@ -911,7 +911,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q4_0(vx_tmp, y, dst, ncols, nrows, item_ct1);
|
||||
dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
@ -1083,7 +1083,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|||
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
|
||||
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
|
||||
const int64_t src1_ncols, const int64_t src1_padded_row_size,
|
||||
const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
||||
const dpct::queue_ptr &stream) {
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
|
@ -1111,50 +1111,51 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|||
#endif // GGML_SYCL_F16
|
||||
ggml_sycl_pool_alloc<char> src0_test(ctx.pool());
|
||||
char *src0_test_ptr = src0_test.alloc(ggml_nbytes(src0));
|
||||
|
||||
for (int i = 0; i < src1_ncols; i++)
|
||||
{
|
||||
const dfloat* src1_dfloat_bs = src1_dfloat + i * src1_padded_col_size;
|
||||
float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream,src0_test_ptr);
|
||||
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
||||
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_ddq_i;
|
||||
(void) src1_ncols;
|
||||
(void) src1_padded_row_size;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue