support new q4_0 layout
This commit is contained in:
parent
127d62fa06
commit
216201230c
2 changed files with 101 additions and 31 deletions
|
@ -1021,6 +1021,60 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
|
|||
(void) dst;
|
||||
}
|
||||
|
||||
// Gathers rows of a Q4_0 tensor stored in the reordered layout and writes them
// to dst as dequantized floats.
//
// Layout read from src0_dd (as fixed by the pointer arithmetic below):
//   - first  ne01 * ne00 / 2 bytes: packed 4-bit quants, two elements per byte.
//     Flattened element e (= row * ne00 + col) lives in byte e / 2; the even
//     element is in the low nibble, the odd one in the high nibble.
//   - then   one sycl::half scale per QK4_0 consecutive elements, indexed by
//     e / QK4_0.
//
// Parameters:
//   ctx      backend context (unused here; kept for a uniform signature)
//   src0     quantized source tensor; only its shape/stride metadata is read
//   src1     tensor of row indices; only its shape/stride metadata is read
//   dst      destination tensor; only its shape/stride metadata is read
//   src0_dd  device pointer to the reordered Q4_0 data of src0
//   src1_dd  device pointer to the int32 row indices
//   dst_dd   device pointer to the float output
//   stream   queue the kernel is submitted to (not waited on here)
static void get_rows_sycl_q4_0(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
                               ggml_tensor *dst, const void *src0_dd,
                               const int32_t *src1_dd, float *dst_dd,
                               queue_ptr stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    // strides in elements (the innermost dst stride and outermost src1 stride
    // are not needed by the kernel)
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);

    // Each work-item emits two adjacent output elements from one packed byte.
    GGML_ASSERT(ne00 % 2 == 0);
    // The scale lookup (offset / QK4_0) and the scale-array base
    // (nrows * ncols / 2) are only correct when rows are whole Q4_0 blocks.
    GGML_ASSERT(ne00 % QK4_0 == 0);

    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
    const int block_num_x = (ne00 + 2 * SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2 * SYCL_GET_ROWS_BLOCK_SIZE);
    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);

    const size_t ncols = ne00;
    const size_t nrows = ne01;

    // Quant nibbles come first, the per-block scales follow them.
    const uint8_t *    src0_q = static_cast<const uint8_t *>(src0_dd);
    const sycl::half * src0_d = reinterpret_cast<const sycl::half *>(src0_q + nrows * ncols / 2);

    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                         [=](sycl::nd_item<3> item_ct1)[[intel::reqd_sub_group_size(WARP_SIZE)]] {
        // Column of the first of the two elements handled by this work-item.
        const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
                         item_ct1.get_local_id(2)) * 2;
        const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                        item_ct1.get_local_id(1);
        const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                         item_ct1.get_local_id(0)) / ne12;
        const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
                         item_ct1.get_local_id(0)) % ne12;

        // The grid is rounded up to whole work-groups; drop the overhang.
        if (i00 >= ne00) {
            return;
        }

        // Row of src0 selected by the index tensor.
        const int i01 = src1_dd[i10 * s10 + i11 * s11 + i12 * s12];

        float * dst_row = dst_dd + i10 * s1 + i11 * s2 + i12 * s3;

        // Keep the flattened offset in size_t: narrowing it to int truncates
        // for tensors with 2^31 or more elements and reads the wrong data.
        const size_t src0_off = static_cast<size_t>(i01) * ncols + static_cast<size_t>(i00);

        const int   vui = src0_q[src0_off / 2];
        const float d   = src0_d[src0_off / QK4_0];

        // Q4_0 stores values biased by 8; the low nibble is the even element.
        dst_row[i00 + 0] = ((vui & 0xF) - 8) * d;
        dst_row[i00 + 1] = ((vui >> 4) - 8) * d;
    });

    (void) ctx;
}
|
||||
|
||||
template <typename src0_t>
|
||||
static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
|
@ -2146,7 +2200,8 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_te
|
|||
get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
//get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl_q4_0(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
|
@ -4283,12 +4338,51 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
|
||||
if (tensor->type == GGML_TYPE_Q4_0)
|
||||
{
|
||||
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
|
||||
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
||||
int blk_offset = offset / sizeof(block_q4_0);
|
||||
auto qs_ptr = (uint8_t*)tensor->data + blk_offset * QK4_0 / 2;
|
||||
size_t ncols = tensor->ne[0];
|
||||
size_t nrows = tensor->ne[1];
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + blk_offset;
|
||||
stream->parallel_for(
|
||||
size / QK4_0,
|
||||
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const block_q4_0* x = (const block_q4_0*)data;
|
||||
int ib = i;
|
||||
typedef sycl::vec<uint8_t, QK4_0 / 2> CT;
|
||||
CT tmp = *(CT*)x[ib].qs;
|
||||
for (int j = 0; j < QK4_0 / 2; j += 2)
|
||||
{
|
||||
const int vui = tmp[j];
|
||||
const int vui1 = tmp[j + 1];
|
||||
uint8_t nv = (vui & 0xF) | (vui1 << 4);
|
||||
*(qs_ptr + ib * QK4_0 / 2 + j / 2) = nv;
|
||||
}
|
||||
for (int j = 0; j < QK4_0 / 2; j += 2)
|
||||
{
|
||||
const int vui = tmp[j];
|
||||
const int vui1 = tmp[j + 1];
|
||||
uint8_t nv = (vui >> 4) | (vui1 & 0xf0);
|
||||
*(qs_ptr + ib * QK4_0 / 2 + j / 2 + QK4_0 / 4) = nv;
|
||||
}
|
||||
*(d_ptr + ib) = x[ib].d;
|
||||
|
||||
});
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
char* host_buf = (char*)malloc(size);
|
||||
memcpy(host_buf, data, size);
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
|
||||
CHECK_TRY_ERROR((*stream).memcpy((char*)tensor->data + offset, host_buf, size)
|
||||
.wait()));
|
||||
free(host_buf);
|
||||
}
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
|
|
|
@ -905,30 +905,6 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
const sycl::range<3> block_nums(1, 1, nrows);
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
nrows * ncols / QK4_0,
|
||||
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const block_q4_0 *x = (const block_q4_0 *)vx;
|
||||
int ib = i;
|
||||
typedef sycl::vec<uint8_t, QK4_0 / 2> CT;
|
||||
CT tmp = *(CT *)x[ib].qs;
|
||||
for (int j = 0; j < QK4_0 / 2; j += 2)
|
||||
{
|
||||
const int vui = tmp[j];
|
||||
const int vui1 = tmp[j + 1];
|
||||
uint8_t nv = (vui & 0xF) | (vui1 << 4);
|
||||
*(uint8_t *)(vx_tmp + ib * QK4_0 / 2 + j / 2) = nv;
|
||||
}
|
||||
for (int j = 0; j < QK4_0 / 2; j += 2)
|
||||
{
|
||||
const int vui = tmp[j];
|
||||
const int vui1 = tmp[j + 1];
|
||||
uint8_t nv = (vui >> 4) | (vui1 & 0xf0);
|
||||
*(uint8_t *)(vx_tmp + ib * QK4_0 / 2 + j / 2 + QK4_0 / 4) = nv;
|
||||
}
|
||||
*(sycl::half *)(vx_tmp + ncols * nrows / 2 + ib * 2) = x[ib].d;
|
||||
|
||||
});
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue