Fix the unaligned size

This commit is contained in:
luoyu-intel 2024-07-15 15:15:41 +08:00
parent d9b89eb308
commit 127d62fa06

View file

@ -112,6 +112,7 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo
int constexpr Unroll = 2; int constexpr Unroll = 2;
const int iqs = tid; // x quant index const int iqs = tid; // x quant index
int ncols_pad = ncols - ncols % (WarpK * Unroll); int ncols_pad = ncols - ncols % (WarpK * Unroll);
int ncols_pad1 = ncols - ncols % (WarpK * 1);
int i = 0; int i = 0;
for (; i < ncols_pad; i += WarpK * Unroll) { for (; i < ncols_pad; i += WarpK * Unroll) {
#pragma unroll #pragma unroll
@ -140,30 +141,33 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo
} }
} }
} }
for (; i < ncols_pad; i += WarpK * 1) { if (i + WarpK <= ncols_pad1)
{
for (; i < ncols_pad1; i += WarpK * 1) {
#pragma unroll #pragma unroll
for (int iu = 0; iu < 1; iu++) for (int iu = 0; iu < 1; iu++)
{
const int iybs = i + tid * QK4_0 + iu * WarpK; // y block start index
const int ib = (row * ncols + i) / QK4_0 + tid + iu * WARP_SIZE; // x block index
const dfloat d = *(sycl::half *)((char *)x + ncols * nrows / 2 + ib * 2);
sycl::vec<uint8_t, QK4_0 / 2> tmp_qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)((char *)x + ib * QK4_0 / 2);
int constexpr KUnroll = 1;
#pragma unroll
for (int ir = 0; ir < QK4_0 / 2; ir += KUnroll)
{ {
const int vui = tmp_qs[ir]; const int iybs = i + tid * QK4_0 + iu * WarpK; // y block start index
dfloat2 v; const int ib = (row * ncols + i) / QK4_0 + tid + iu * WARP_SIZE; // x block index
v.x() = ((vui & 0xF) - 8) * d; const dfloat d = *(sycl::half*)((char*)x + ncols * nrows / 2 + ib * 2);
v.y() = ((vui >> 4) - 8) * d; sycl::vec<uint8_t, QK4_0 / 2> tmp_qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)((char*)x + ib * QK4_0 / 2);
int constexpr KUnroll = 1;
#pragma unroll
for (int ir = 0; ir < QK4_0 / 2; ir += KUnroll)
{
const int vui = tmp_qs[ir];
dfloat2 v;
v.x() = ((vui & 0xF) - 8) * d;
v.y() = ((vui >> 4) - 8) * d;
#ifdef GGML_SYCL_F16 #ifdef GGML_SYCL_F16
dfloat2 t1{ y[iybs + ir * 2 + 0], dfloat2 t1{ y[iybs + ir * 2 + 0],
y[iybs + ir * 2 + 1] }; y[iybs + ir * 2 + 1] };
tmp += v * t1; tmp += v * t1;
#else #else
tmp += v.x() * y[iybs + ir * 2 + 0]; tmp += v.x() * y[iybs + ir * 2 + 0];
tmp += v.y() * y[iybs + ir * 2 + 1]; tmp += v.y() * y[iybs + ir * 2 + 1];
#endif #endif
}
} }
} }
} }