Merge pull request #1 from arthw/update_warp

[SYCL] Fix WARP_SIZE=16 bug of Intel GPU (#8266) cherry-pick b549a1bbef
This commit is contained in:
Neo Zhang 2024-07-13 16:44:28 +08:00 committed by GitHub
commit aeaed61904
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 203 additions and 70 deletions

View file

@ -490,7 +490,7 @@ if (GGML_SYCL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
add_compile_definitions(GGML_SYCL_WARP_SIZE=32) add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
else() else()
add_compile_definitions(GGML_SYCL_WARP_SIZE=32) add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
endif() endif()
file(GLOB GGML_HEADERS_SYCL "ggml-sycl/*.hpp") file(GLOB GGML_HEADERS_SYCL "ggml-sycl/*.hpp")

View file

@ -906,6 +906,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
const int nthreads = block_size;
const int nwarps = nthreads / WARP_SIZE;
int nreduce = nwarps / WARP_SIZE;
float slope = 1.0f; float slope = 1.0f;
@ -919,7 +923,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
slope = sycl::pow(base, float(exp)); slope = sycl::pow(base, float(exp));
} }
float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols; float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
float max_val = -INFINITY; float max_val = -INFINITY;
for (int col0 = 0; col0 < ncols; col0 += block_size) { for (int col0 = 0; col0 < ncols; col0 += block_size) {
@ -943,6 +947,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
if (block_size > WARP_SIZE) { if (block_size > WARP_SIZE) {
if (warp_id == 0) { if (warp_id == 0) {
buf[lane_id] = -INFINITY; buf[lane_id] = -INFINITY;
for (size_t i = 1; i < nreduce; i += 1)
buf[lane_id + i * WARP_SIZE] = -INFINITY;
} }
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
@ -952,6 +959,11 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
max_val = buf[lane_id]; max_val = buf[lane_id];
for (size_t i = 1; i < nreduce; i += 1)
{
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
}
max_val = warp_reduce_max(max_val, item_ct1); max_val = warp_reduce_max(max_val, item_ct1);
} }
@ -975,6 +987,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
if (warp_id == 0) { if (warp_id == 0) {
buf[lane_id] = 0.f; buf[lane_id] = 0.f;
for (size_t i = 1; i < nreduce; i += 1)
buf[lane_id + i * WARP_SIZE] = 0.f;
} }
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
@ -984,6 +999,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
tmp = buf[lane_id]; tmp = buf[lane_id];
for (size_t i = 1; i < nreduce; i += 1)
{
tmp += buf[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1); tmp = warp_reduce_sum(tmp, item_ct1);
} }

View file

@ -314,6 +314,7 @@ void sycl_device_mgr::detect_all_sycl_device_list() try {
dpct::get_device_info(prop, device); dpct::get_device_info(prop, device);
work_group_sizes.push_back(prop.get_max_work_group_size()); work_group_sizes.push_back(prop.get_max_work_group_size());
max_compute_units.push_back(prop.get_max_compute_units()); max_compute_units.push_back(prop.get_max_compute_units());
hw_familys.push_back(get_device_family(&device));
} }
return; return;
} catch (sycl::exception const &exc) { } catch (sycl::exception const &exc) {
@ -498,4 +499,8 @@ int ggml_sycl_device_info::get_device_id(int device_index) {
} }
} }
int ggml_sycl_device_info::hw_family(int device_id) {
return device_mgr->hw_familys[device_id];
}
//--ggml_sycl_device_info-- //--ggml_sycl_device_info--

View file

@ -20,6 +20,7 @@
#include "dpct/helper.hpp" #include "dpct/helper.hpp"
#include "ggml-sycl.h" #include "ggml-sycl.h"
#include "presets.hpp" #include "presets.hpp"
#include "sycl_hw.hpp"
#define GGML_COMMON_DECL_SYCL #define GGML_COMMON_DECL_SYCL
#define GGML_COMMON_IMPL_SYCL #define GGML_COMMON_IMPL_SYCL
@ -188,6 +189,8 @@ class sycl_device_mgr {
std::vector<sycl::device> devices; std::vector<sycl::device> devices;
std::vector<int> max_compute_units; std::vector<int> max_compute_units;
std::vector<int> work_group_sizes; std::vector<int> work_group_sizes;
std::vector<int> hw_familys;
sycl::queue *first_queue; sycl::queue *first_queue;
std::vector<sycl::queue> _queues; std::vector<sycl::queue> _queues;
std::vector<sycl::context> ctxs; std::vector<sycl::context> ctxs;
@ -236,6 +239,7 @@ struct ggml_sycl_device_info {
bool is_allowed_device(int device_id); bool is_allowed_device(int device_id);
const char* devices_list(); const char* devices_list();
int get_device_id(int device_index); int get_device_id(int device_index);
int hw_family(int device_id);
}; };
struct ggml_sycl_pool { struct ggml_sycl_pool {

View file

@ -20,8 +20,10 @@ static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 &
} }
template <int qk, int qr, dequantize_kernel_t dequantize_kernel> template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst,
const int ncols, const int nrows, const int warp_size,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
// qk = quantized weights per x block // qk = quantized weights per x block
// qr = number of quantized weights per data value in x block // qr = number of quantized weights per data value in x block
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
@ -34,7 +36,7 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
const int tid = item_ct1.get_local_id(2); const int tid = item_ct1.get_local_id(2);
const int iter_stride = 2*GGML_SYCL_DMMV_X; const int iter_stride = 2*GGML_SYCL_DMMV_X;
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter const int vals_per_iter = iter_stride / warp_size; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
// partial sum for each thread // partial sum for each thread
@ -76,7 +78,7 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = warp_size / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -93,21 +95,32 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols, float *dst, const int ncols,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream,
int device_id) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (ggml_sycl_info().hw_family(device_id) == SYCL_HW_FAMILY_INTEL_IGPU) {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_32_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1)
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
nrows, item_ct1); dequantize_mul_mat_vec<1, 1, convert_f16>(
}); vx, y, dst, ncols, nrows, WARP_32_SIZE, item_ct1);
});
} else {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(
vx, y, dst, ncols, nrows, WARP_SIZE, item_ct1);
});
} }
} }
@ -227,7 +240,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = WARP_32_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -346,7 +359,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = WARP_32_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -499,7 +512,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = WARP_32_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -633,7 +646,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = WARP_32_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -748,7 +761,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = WARP_32_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -762,21 +775,31 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols, float *dst, const int ncols,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream, int device_id) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (ggml_sycl_info().hw_family(device_id) == SYCL_HW_FAMILY_INTEL_IGPU) {
// printf("dequantize_mul_mat_vec_q4_0_sycl warp_size=%d\n", WARP_32_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_32_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
vx, y, dst, ncols, nrows, WARP_32_SIZE, item_ct1);
});
} else {
// printf("dequantize_mul_mat_vec_q4_0_sycl warp_size=%d\n", WARP_SIZE);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>( dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, WARP_SIZE, item_ct1);
}); });
} }
} }
@ -784,20 +807,27 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y, static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols, float *dst, const int ncols,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream, int device_id) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (ggml_sycl_info().hw_family(device_id) == SYCL_HW_FAMILY_INTEL_IGPU) {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_32_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
vx, y, dst, ncols, nrows, WARP_32_SIZE, item_ct1);
});
} else {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>( dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, WARP_SIZE, item_ct1);
}); });
} }
} }
@ -805,20 +835,27 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y, static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols, float *dst, const int ncols,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream, int device_id) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (ggml_sycl_info().hw_family(device_id) == SYCL_HW_FAMILY_INTEL_IGPU) {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_32_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
vx, y, dst, ncols, nrows, WARP_32_SIZE, item_ct1);
});
} else {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>( dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, WARP_SIZE, item_ct1);
}); });
} }
} }
@ -826,20 +863,27 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols, float *dst, const int ncols,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream, int device_id) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (ggml_sycl_info().hw_family(device_id) == SYCL_HW_FAMILY_INTEL_IGPU) {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_32_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
vx, y, dst, ncols, nrows, WARP_32_SIZE, item_ct1);
});
} else {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>( dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, WARP_SIZE, item_ct1);
}); });
} }
} }
@ -847,20 +891,27 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols, float *dst, const int ncols,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream, int device_id) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (ggml_sycl_info().hw_family(device_id) == SYCL_HW_FAMILY_INTEL_IGPU) {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_32_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
vx, y, dst, ncols, nrows, WARP_32_SIZE, item_ct1);
});
} else {
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>( dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
vx, y, dst, ncols, nrows, item_ct1); vx, y, dst, ncols, nrows, WARP_SIZE, item_ct1);
}); });
} }
} }
@ -873,10 +924,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, WARP_SIZE); const sycl::range<3> block_dims(1, ny, WARP_32_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -889,10 +940,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, WARP_SIZE); const sycl::range<3> block_dims(1, ny, WARP_32_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -905,10 +956,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, WARP_SIZE); const sycl::range<3> block_dims(1, ny, WARP_32_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -918,10 +969,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_32_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
}); });
} }
@ -934,10 +985,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, WARP_SIZE); const sycl::range<3> block_dims(1, ny, WARP_32_SIZE);
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_32_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
@ -976,19 +1027,19 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, ctx.device);
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, ctx.device);
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, ctx.device);
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, ctx.device);
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, ctx.device);
break; break;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
@ -1006,7 +1057,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break; break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, ctx.device);
break; break;
default: default:
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);

View file

@ -57,6 +57,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
const int nwarps = nthreads / WARP_SIZE; const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0); assert(nwarps % WARP_SIZE == 0);
start += item_ct1.get_local_id(2); start += item_ct1.get_local_id(2);
int nreduce = nwarps / WARP_SIZE;
if (end >= ne_elements) { if (end >= ne_elements) {
end = ne_elements; end = ne_elements;
@ -87,7 +88,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
*/ */
item_ct1.barrier(); item_ct1.barrier();
tmp = 0.f; tmp = 0.f;
int nreduce = nwarps / WARP_SIZE;
for (size_t i = 0; i < nreduce; i += 1) for (size_t i = 0; i < nreduce; i += 1)
{ {
tmp += s_sum[lane_id + i * WARP_SIZE]; tmp += s_sum[lane_id + i * WARP_SIZE];
@ -122,7 +122,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
better performance if there is no access to global memory. better performance if there is no access to global memory.
*/ */
item_ct1.barrier(); item_ct1.barrier();
tmp = s_sum[lane_id]; tmp = 0.f;
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += s_sum[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1); tmp = warp_reduce_sum(tmp, item_ct1);
} }
@ -186,13 +190,15 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
if (ncols < 1024) { if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
sycl::range<1>(32), cgh);
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1, norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE); s_sum_acc_ct1.get_pointer(), WARP_SIZE);
}); });
}); });
} }
@ -227,6 +233,8 @@ static void group_norm_f32_sycl(const float* x, float* dst,
if (group_size < 1024) { if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
cgh);
const float eps_ct4 = eps; const float eps_ct4 = eps;
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
@ -235,7 +243,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32( group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1, x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE); s_sum_acc_ct1.get_pointer(), WARP_SIZE);
}); });
}); });
} }
@ -275,13 +283,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
if (ncols < 1024) { if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
cgh);
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1, rms_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE); s_sum_acc_ct1.get_pointer(), WARP_SIZE);
}); });
}); });
} }

View file

@ -17,6 +17,8 @@
#define GGML_SYCL_MAX_BUFFERS 256 #define GGML_SYCL_MAX_BUFFERS 256
#define WARP_SIZE GGML_SYCL_WARP_SIZE #define WARP_SIZE GGML_SYCL_WARP_SIZE
#define WARP_32_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#define SYCL_GELU_BLOCK_SIZE 256 #define SYCL_GELU_BLOCK_SIZE 256
@ -62,4 +64,5 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
#define MUL_MAT_SRC1_COL_STRIDE 128 #define MUL_MAT_SRC1_COL_STRIDE 128
#endif // GGML_SYCL_PRESETS_HPP #endif // GGML_SYCL_PRESETS_HPP

View file

@ -0,0 +1,17 @@
#include "sycl_hw.hpp"
bool is_in_vector(const std::vector<int> &vec, int item) {
return std::find(vec.begin(), vec.end(), item) != vec.end();
}
SYCL_HW_FAMILY get_device_family(sycl::device *device_ptr) {
auto id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
auto id_prefix = id & 0xff00;
if (is_in_vector(Xe_Iris_IDs, id_prefix) or is_in_vector(UHD_IDs, id_prefix)) {
return SYCL_HW_FAMILY_INTEL_IGPU;
} else {
std::cerr << "No support PCI_ID: " << std::hex << id << std::endl;
return SYCL_HW_FAMILY_UNKNOWN;
}
}

View file

@ -0,0 +1,24 @@
#ifndef SYCL_HW_HPP
#define SYCL_HW_HPP
#include <algorithm>
#include <stdio.h>
#include <vector>
#include <sycl/sycl.hpp>
// const int Xe_ARC[] = {0x5600, 0x4f};
const std::vector<int> Xe_Iris_IDs = {0x4900, 0xa700};
const std::vector<int> UHD_IDs = {0x4600};
enum SYCL_HW_FAMILY {
SYCL_HW_FAMILY_UNKNOWN = -1,
SYCL_HW_FAMILY_INTEL_IGPU = 0
};
bool is_in_vector(std::vector<int> &vec, int item);
SYCL_HW_FAMILY get_device_family(sycl::device *device_ptr);
#endif // SYCL_HW_HPP