add new q40 dmmv
This commit is contained in:
parent
e02b597be3
commit
d838096ebe
1 changed files with 72 additions and 0 deletions
|
@ -91,6 +91,63 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
|
|||
}
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1);
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
|
||||
#else
|
||||
float tmp = 0.0f;
|
||||
#endif // GGML_SYCL_F16
|
||||
int constexpr ColTile = QK4_0 / WARP_SIZE;
|
||||
static_assert(QK4_0 % WARP_SIZE == 0);
|
||||
static_assert(ColTile == 2);
|
||||
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
for (int i = 0; i < ncols; i += QK4_0) {
|
||||
const int col = i + tid * ColTile;
|
||||
const int ib = (row * ncols + col) / QK4_0; // x block index
|
||||
const int iqs = (col % QK4_0) / QR4_0; // x quant index
|
||||
const int iybs = col - col % QK4_0; // y block start index
|
||||
const dfloat d = x[ib].d;
|
||||
|
||||
const int vui = x[ib].qs[iqs];
|
||||
dfloat2 v;
|
||||
v.x() = (vui & 0xF) * d;
|
||||
v.y() = (vui >> 4) * d;
|
||||
#ifdef GGML_SYCL_F16
|
||||
dfloat2 t1{ y[iybs + iqs + 0],
|
||||
y[iybs + iqs + QK4_0 / 2] };
|
||||
tmp += v * t1;
|
||||
#else
|
||||
tmp += v.x() * y[iybs + iqs + 0];
|
||||
tmp += v.y() * y[iybs + iqs + QK4_0 / 2];
|
||||
#endif
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
#ifdef GGML_SYCL_F16
|
||||
dst[row] = tmp.x() + tmp.y();
|
||||
#else
|
||||
dst[row] = tmp;
|
||||
#endif // GGML_SYCL_F16
|
||||
}
|
||||
}
|
||||
|
||||
static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
|
@ -764,6 +821,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
#if WARP_SIZE==32
|
||||
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
|
||||
|
@ -780,6 +838,20 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
#else
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
const sycl::range<3> block_nums(1, 1, nrows);
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue