From 127d62fa06b006b9f07a362fa7abc6517db778d7 Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Mon, 15 Jul 2024 15:15:41 +0800 Subject: [PATCH] fix the unaligned size --- ggml/src/ggml-sycl/dmmv.cpp | 42 ++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 0f89c5d18..a1551ab86 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -112,6 +112,7 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo int constexpr Unroll = 2; const int iqs = tid; // x quant index int ncols_pad = ncols - ncols % (WarpK * Unroll); + int ncols_pad1 = ncols - ncols % (WarpK * 1); int i = 0; for (; i < ncols_pad; i += WarpK * Unroll) { #pragma unroll @@ -140,30 +141,33 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo } } } - for (; i < ncols_pad; i += WarpK * 1) { + if (i + WarpK <= ncols_pad1) + { + for (; i < ncols_pad1; i += WarpK * 1) { #pragma unroll - for (int iu = 0; iu < 1; iu++) - { - const int iybs = i + tid * QK4_0 + iu * WarpK; // y block start index - const int ib = (row * ncols + i) / QK4_0 + tid + iu * WARP_SIZE; // x block index - const dfloat d = *(sycl::half *)((char *)x + ncols * nrows / 2 + ib * 2); - sycl::vec tmp_qs = *(sycl::vec*)((char *)x + ib * QK4_0 / 2); - int constexpr KUnroll = 1; -#pragma unroll - for (int ir = 0; ir < QK4_0 / 2; ir += KUnroll) + for (int iu = 0; iu < 1; iu++) { - const int vui = tmp_qs[ir]; - dfloat2 v; - v.x() = ((vui & 0xF) - 8) * d; - v.y() = ((vui >> 4) - 8) * d; + const int iybs = i + tid * QK4_0 + iu * WarpK; // y block start index + const int ib = (row * ncols + i) / QK4_0 + tid + iu * WARP_SIZE; // x block index + const dfloat d = *(sycl::half*)((char*)x + ncols * nrows / 2 + ib * 2); + sycl::vec tmp_qs = *(sycl::vec*)((char*)x + ib * QK4_0 / 2); + int constexpr KUnroll = 1; +#pragma unroll + for (int ir = 0; ir < QK4_0 / 2; ir += KUnroll) 
+ { + const int vui = tmp_qs[ir]; + dfloat2 v; + v.x() = ((vui & 0xF) - 8) * d; + v.y() = ((vui >> 4) - 8) * d; #ifdef GGML_SYCL_F16 - dfloat2 t1{ y[iybs + ir * 2 + 0], - y[iybs + ir * 2 + 1] }; - tmp += v * t1; + dfloat2 t1{ y[iybs + ir * 2 + 0], + y[iybs + ir * 2 + 1] }; + tmp += v * t1; #else - tmp += v.x() * y[iybs + ir * 2 + 0]; - tmp += v.y() * y[iybs + ir * 2 + 1]; + tmp += v.x() * y[iybs + ir * 2 + 0]; + tmp += v.y() * y[iybs + ir * 2 + 1]; #endif + } } } }