[SYCL] Fix the sub group size of Intel (#8106)
* use warp_size macro for all sycl kernels * fix mask of permute_sub_group_by_xor * fix rms_norm with correct warp number * fix rms_norm_f32/group_norm_f32 * move norm to norm.cpp file * fix quantize bug * fix mmvq's batch size
This commit is contained in:
parent
5fac350b9c
commit
d08c20edde
9 changed files with 587 additions and 509 deletions
|
@ -76,7 +76,7 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
|||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
|
@ -227,7 +227,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -346,7 +346,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -499,7 +499,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -633,7 +633,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -748,7 +748,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -774,7 +774,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -795,7 +795,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
|
|||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -816,7 +816,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
|
|||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -837,7 +837,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
|
|||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -858,7 +858,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
|
|||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -873,10 +873,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
|
|||
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, 32);
|
||||
const sycl::range<3> block_dims(1, ny, WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
@ -889,10 +889,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
|||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, 32);
|
||||
const sycl::range<3> block_dims(1, ny, WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
@ -905,10 +905,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
|||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, 32);
|
||||
const sycl::range<3> block_dims(1, ny, WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
@ -918,10 +918,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
|||
const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const sycl::range<3> block_dims(1, 1, 32);
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
|
||||
});
|
||||
}
|
||||
|
@ -934,10 +934,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
|||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, 32);
|
||||
const sycl::range<3> block_dims(1, ny, WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue