pass batch offset for F16 src1

This commit is contained in:
luoyu-intel 2024-01-24 21:46:48 +08:00
parent 5600118221
commit eef5faae18

View file

@ -12098,7 +12098,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
if (src1_convert_f16) { if (src1_convert_f16) {
if (src1->type == GGML_TYPE_F16) { if (src1->type == GGML_TYPE_F16) {
src1_dfloat = (sycl::half *)src1->data + row_low * src1_ncols; src1_dfloat = (sycl::half *)src1->data + src1_padded_row_size;
} else { } else {
src1_dfloat = src1_dfloat_a.alloc(ne00); src1_dfloat = src1_dfloat_a.alloc(ne00);
ggml_cpy_f32_f16_sycl((const char *)src1_ddf_i, (char *)src1_dfloat, ggml_cpy_f32_f16_sycl((const char *)src1_ddf_i, (char *)src1_dfloat,
@ -12729,7 +12729,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
const bool src0_is_contiguous = ggml_is_contiguous(src0); const bool src0_is_contiguous = ggml_is_contiguous(src0);
const bool src1_is_contiguous = ggml_is_contiguous(src1); const bool src1_is_contiguous = ggml_is_contiguous(src1);
const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
GGML_ASSERT(!(split && ne02 > 1)); GGML_ASSERT(!(split && ne02 > 1));
@ -12919,7 +12919,9 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream)); SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream));
} }
if (src1->type == GGML_TYPE_F16) {
src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
}
// do the computation // do the computation
op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream); row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream);