revert cpy code
This commit is contained in:
parent
92784e8059
commit
d9b89eb308
1 changed files with 27 additions and 14 deletions
|
@ -130,12 +130,12 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo
|
|||
v.x() = ((vui & 0xF) - 8) * d;
|
||||
v.y() = ((vui >> 4) - 8) * d;
|
||||
#ifdef GGML_SYCL_F16
|
||||
dfloat2 t1{ y[iybs + ir + 0],
|
||||
y[iybs + ir + QK4_0 / 2] };
|
||||
dfloat2 t1{ y[iybs + ir * 2 + 0],
|
||||
y[iybs + ir * 2 + 1] };
|
||||
tmp += v * t1;
|
||||
#else
|
||||
tmp += v.x() * y[iybs + ir + 0];
|
||||
tmp += v.y() * y[iybs + ir + QK4_0 / 2];
|
||||
tmp += v.x() * y[iybs + ir * 2 + 0];
|
||||
tmp += v.y() * y[iybs + ir * 2 + 1];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -157,12 +157,12 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo
|
|||
v.x() = ((vui & 0xF) - 8) * d;
|
||||
v.y() = ((vui >> 4) - 8) * d;
|
||||
#ifdef GGML_SYCL_F16
|
||||
dfloat2 t1{ y[iybs + ir + 0],
|
||||
y[iybs + ir + QK4_0 / 2] };
|
||||
dfloat2 t1{ y[iybs + ir * 2 + 0],
|
||||
y[iybs + ir * 2 + 1] };
|
||||
tmp += v * t1;
|
||||
#else
|
||||
tmp += v.x() * y[iybs + ir + 0];
|
||||
tmp += v.y() * y[iybs + ir + QK4_0 / 2];
|
||||
tmp += v.x() * y[iybs + ir * 2 + 0];
|
||||
tmp += v.y() * y[iybs + ir * 2 + 1];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -180,12 +180,12 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo
|
|||
v.x() = ((vui & 0xF) - 8) * d;
|
||||
v.y() = ((vui >> 4) - 8) * d;
|
||||
#ifdef GGML_SYCL_F16
|
||||
dfloat2 t1{ y[iybs + ir * QK4_0 + iqs + 0],
|
||||
y[iybs + ir * QK4_0 + iqs + QK4_0 / 2] };
|
||||
dfloat2 t1{ y[iybs + ir * QK4_0 + iqs * 2 + 0],
|
||||
y[iybs + ir * QK4_0 + iqs * 2 + 1] };
|
||||
tmp += v * t1;
|
||||
#else
|
||||
tmp += v.x() * y[iybs + ir * QK4_0 + iqs + 0];
|
||||
tmp += v.y() * y[iybs + ir * QK4_0 + iqs + QK4_0 / 2];
|
||||
tmp += v.x() * y[iybs + ir * QK4_0 + iqs * 2 + 0];
|
||||
tmp += v.y() * y[iybs + ir * QK4_0 + iqs * 2 + 1];
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -904,11 +904,24 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|||
stream->parallel_for(
|
||||
nrows * ncols / QK4_0,
|
||||
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
const block_q4_0 *x = (const block_q4_0 *)vx;
|
||||
int ib = i;
|
||||
typedef sycl::vec<uint8_t, QK4_0 / 2> CT;
|
||||
CT tmp = *(CT *)x[ib].qs;
|
||||
*(CT*)(vx_tmp + ib * QK4_0 / 2) = tmp;
|
||||
for (int j = 0; j < QK4_0 / 2; j += 2)
|
||||
{
|
||||
const int vui = tmp[j];
|
||||
const int vui1 = tmp[j + 1];
|
||||
uint8_t nv = (vui & 0xF) | (vui1 << 4);
|
||||
*(uint8_t *)(vx_tmp + ib * QK4_0 / 2 + j / 2) = nv;
|
||||
}
|
||||
for (int j = 0; j < QK4_0 / 2; j += 2)
|
||||
{
|
||||
const int vui = tmp[j];
|
||||
const int vui1 = tmp[j + 1];
|
||||
uint8_t nv = (vui >> 4) | (vui1 & 0xf0);
|
||||
*(uint8_t *)(vx_tmp + ib * QK4_0 / 2 + j / 2 + QK4_0 / 4) = nv;
|
||||
}
|
||||
*(sycl::half *)(vx_tmp + ncols * nrows / 2 + ib * 2) = x[ib].d;
|
||||
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue