fix mmvq's batch size

This commit is contained in:
luoyu-intel 2024-06-28 11:26:16 +08:00
parent 61f0cd58dc
commit eb0d1325af

View file

@ -936,7 +936,7 @@ void ggml_sycl_op_mul_mat_vec_q(
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
float *dst_dd_i, const int64_t row_low, const int64_t row_high, float *dst_dd_i, const int64_t row_low, const int64_t row_high,
const int64_t src1_ncols, const int64_t src1_padded_row_size, const int64_t src1_ncols, const int64_t src1_padded_col_size,
const dpct::queue_ptr &stream) { const dpct::queue_ptr &stream) {
const int64_t ne10 = src1->ne[0]; const int64_t ne10 = src1->ne[0];
@ -948,77 +948,82 @@ void ggml_sycl_op_mul_mat_vec_q(
int id; int id;
SYCL_CHECK( SYCL_CHECK(
CHECK_TRY_ERROR(id = get_current_device_id())); CHECK_TRY_ERROR(id = get_current_device_id()));
const size_t q8_1_ts = sizeof(block_q8_1);
const size_t q8_1_bs = QK8_1;
// the main device has a larger memory buffer to hold the results from all GPUs // the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into // nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff; const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff;
for (int i = 0; i < src1_ncols; i++)
switch (src0->type) { {
const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_M:
mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_XS:
mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
break; break;
} }
}
(void) src1; (void) src1;
(void) dst; (void) dst;
(void) src1_ddf_i; (void) src1_ddf_i;
(void) src1_ncols;
(void) src1_padded_row_size; (void) src1_padded_row_size;
} }