diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index 36518ff93..8d82a837e 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -1021,6 +1021,60 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
     (void) dst;
 }
 
+static void get_rows_sycl_q4_0(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                               ggml_tensor *dst, const void *src0_dd,
+                               const int32_t *src1_dd, float *dst_dd,
+                               queue_ptr stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + 2 * SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2 * SYCL_GET_ROWS_BLOCK_SIZE);
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+    uint8_t* src0_q = (uint8_t*)src0_dd;
+    const size_t ncols = ne00;
+    const size_t nrows = ne01;
+    sycl::half* src0_d = (sycl::half*)(src0_q + nrows * ncols / 2);
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1)[[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                             item_ct1.get_local_id(2)) * 2;
+            const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                            item_ct1.get_local_id(1);
+            const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                             item_ct1.get_local_id(0)) / ne12;
+            const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                             item_ct1.get_local_id(0)) % ne12;
+
+            if (i00 >= ne00) {
+                return;
+            }
+
+            const int i01 = src1_dd[i10 * s10 + i11 * s11 + i12 * s12];
+            float* dst_row = dst_dd + i10 * s1 + i11 * s2 + i12 * s3;
+            const int src0_off = i01 * ncols + i00;
+            const int vui = src0_q[src0_off / 2];
+            float d = src0_d[src0_off / QK4_0];
+            dst_row[i00 + 0] = ((vui & 0xF) - 8) * d;
+            dst_row[i00 + 1] = ((vui >> 4) - 8) * d;
+        });
+}
+
 template <typename src0_t>
 static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
@@ -2146,7 +2200,8 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_te
             get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            //get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_sycl_q4_0(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
             get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
@@ -4283,12 +4338,51 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
     auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
     SYCL_CHECK(
         CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-    char* host_buf = (char*)malloc(size);
-    memcpy(host_buf, data, size);
-    SYCL_CHECK(
-        CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
-                            .wait()));
-    free(host_buf);
+    if (tensor->type == GGML_TYPE_Q4_0)
+    {
+        auto tmp_buf = sycl::malloc_shared(size, *stream);
+        GGML_ASSERT((size % sizeof(block_q4_0) == 0));
+        GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
+        int blk_offset = offset / sizeof(block_q4_0);
+        auto qs_ptr = (uint8_t*)tensor->data + blk_offset * QK4_0 / 2;
+        size_t ncols = tensor->ne[0];
+        size_t nrows = tensor->ne[1];
+        auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + blk_offset;
+        stream->parallel_for(
+            size / QK4_0,
+            [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                const block_q4_0* x = (const block_q4_0*)data;
+                int ib = i;
+                typedef sycl::vec<uint8_t, QK4_0 / 2> CT;
+                CT tmp = *(CT*)x[ib].qs;
+                for (int j = 0; j < QK4_0 / 2; j += 2)
+                {
+                    const int vui = tmp[j];
+                    const int vui1 = tmp[j + 1];
+                    uint8_t nv = (vui & 0xF) | (vui1 << 4);
+                    *(qs_ptr + ib * QK4_0 / 2 + j / 2) = nv;
+                }
+                for (int j = 0; j < QK4_0 / 2; j += 2)
+                {
+                    const int vui = tmp[j];
+                    const int vui1 = tmp[j + 1];
+                    uint8_t nv = (vui >> 4) | (vui1 & 0xf0);
+                    *(qs_ptr + ib * QK4_0 / 2 + j / 2 + QK4_0 / 4) = nv;
+                }
+                *(d_ptr + ib) = x[ib].d;
+
+            });
+        sycl::free(tmp_buf, *stream);
+    }
+    else
+    {
+        char* host_buf = (char*)malloc(size);
+        memcpy(host_buf, data, size);
+        SYCL_CHECK(
+            CHECK_TRY_ERROR((*stream).memcpy((char*)tensor->data + offset, host_buf, size)
+                                .wait()));
+        free(host_buf);
+    }
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp
index a1551ab86..2a3a2046b 100644
--- a/ggml/src/ggml-sycl/dmmv.cpp
+++ b/ggml/src/ggml-sycl/dmmv.cpp
@@ -905,30 +905,6 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const sycl::range<3> block_nums(1, 1, nrows);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-    stream->parallel_for(
-        nrows * ncols / QK4_0,
-        [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-            const block_q4_0 *x = (const block_q4_0 *)vx;
-            int ib = i;
-            typedef sycl::vec<uint8_t, QK4_0 / 2> CT;
-            CT tmp = *(CT *)x[ib].qs;
-            for (int j = 0; j < QK4_0 / 2; j += 2)
-            {
-                const int vui = tmp[j];
-                const int vui1 = tmp[j + 1];
-                uint8_t nv = (vui & 0xF) | (vui1 << 4);
-                *(uint8_t *)(vx_tmp + ib * QK4_0 / 2 + j / 2) = nv;
-            }
-            for (int j = 0; j < QK4_0 / 2; j += 2)
-            {
-                const int vui = tmp[j];
-                const int vui1 = tmp[j + 1];
-                uint8_t nv = (vui >> 4) | (vui1 & 0xf0);
-                *(uint8_t *)(vx_tmp + ib * QK4_0 / 2 + j / 2 + QK4_0 / 4) = nv;
-            }
-            *(sycl::half *)(vx_tmp + ncols * nrows / 2 + ib * 2) = x[ib].d;
-
-        });
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
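
For reference, the reordered Q4_0 layout that both hunks rely on works like this: for each block the 16 qs bytes are rewritten so that byte k holds elements 2k (low nibble) and 2k+1 (high nibble), the packed bytes of all blocks are stored contiguously at the front of the buffer, and the fp16 scales of all blocks are gathered into a dense array placed after nrows * ncols / 2 bytes. Below is a minimal host-side C++ sketch of that re-packing, for illustration only; block_q4_0_ref, QK4_0_REF and reorder_q4_0_host are hypothetical stand-ins for ggml's block_q4_0, QK4_0 and the device-side loop in ggml_backend_sycl_buffer_set_tensor().

// Illustrative sketch only: host-side version of the Q4_0 re-packing that the
// patch performs on-device in ggml_backend_sycl_buffer_set_tensor().
#include <cstddef>
#include <cstdint>

constexpr int QK4_0_REF = 32;               // elements per Q4_0 block (ggml's QK4_0)

struct block_q4_0_ref {                     // mirrors ggml's block_q4_0
    uint16_t d;                             // fp16 scale, kept as raw bits here
    uint8_t  qs[QK4_0_REF / 2];             // byte j: element j (low nibble), element j+16 (high nibble)
};

// Re-pack n_blocks into the reordered layout:
//   qs_out: 16 bytes per block, byte k = elements 2k (low nibble) and 2k+1 (high nibble)
//   d_out : one scale per block, gathered into a dense array
// (in the patch, the scale array starts right after nrows * ncols / 2 bytes of qs data).
static void reorder_q4_0_host(const block_q4_0_ref *src, size_t n_blocks,
                              uint8_t *qs_out, uint16_t *d_out) {
    for (size_t ib = 0; ib < n_blocks; ++ib) {
        const uint8_t *qs  = src[ib].qs;
        uint8_t       *dst = qs_out + ib * (QK4_0_REF / 2);
        for (int j = 0; j < QK4_0_REF / 2; j += 2) {
            // elements j and j+1 (low nibbles of bytes j, j+1) -> first half of the block
            dst[j / 2]                 = (uint8_t)((qs[j] & 0x0F) | (qs[j + 1] << 4));
            // elements j+16 and j+17 (high nibbles of bytes j, j+1) -> second half
            dst[j / 2 + QK4_0_REF / 4] = (uint8_t)((qs[j] >> 4)   | (qs[j + 1] & 0xF0));
        }
        d_out[ib] = src[ib].d;
    }
}

With this layout, get_rows_sycl_q4_0 fetches two consecutive elements from the single byte at src0_off / 2 and the block scale from the dense half array at src0_off / QK4_0, instead of walking interleaved block_q4_0 structs; and because the data now arrives already reordered at set-tensor time, the per-call re-packing pass removed from dequantize_mul_mat_vec_q4_0_sycl is presumably no longer needed.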