Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
Iwan Kawrakow
9f805264dc Attempt 2 2024-03-12 18:40:13 +02:00
Iwan Kawrakow
9188523f70 iq1_s[SYCL]: remove unnecessary (unused) data 2024-03-12 15:20:04 +02:00
Iwan Kawrakow
da5a6f05f6 iq1_s: attempt to fix SYCL 2024-03-12 15:09:07 +02:00

View file

@ -4701,9 +4701,7 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restr
template<typename dst_t> template<typename dst_t>
static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy, static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1, const sycl::nd_item<3> &item_ct1,
const uint32_t *iq1s_grid, const uint32_t *iq1s_grid) {
const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) {
const int i = item_ct1.get_group(2); const int i = item_ct1.get_group(2);
const block_iq1_s * x = (const block_iq1_s *) vx; const block_iq1_s * x = (const block_iq1_s *) vx;
@ -4712,14 +4710,14 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
const int il = tid/8; // 0...3 const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7 const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il; dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib; const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]); const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]); uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1); grid32[0] = iq1s_grid[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7]; grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
for (int j = 0; j < 4; ++j) { grid32[0] &= 0x0f0f0f0f;
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); for (int j = 0; j < 8; ++j) {
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); y[j] = d * (q[j] + delta);
} }
#else #else
assert(false); assert(false);
@ -7616,27 +7614,23 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
static __dpct_inline__ float static __dpct_inline__ float
vec_dot_iq1_s_q8_1(const void *__restrict__ vbq, vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs, const block_q8_1 *__restrict__ bq8_1, const int &iqs,
const uint32_t *iq1s_grid, const uint64_t *ksigns64) { const uint32_t *iq1s_grid) {
#if QK_K == 256 #if QK_K == 256
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int ib32 = iqs; const int ib32 = iqs;
const uint8_t * qs = bq1->qs + 4*ib32; const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int8_t * q8 = bq8_1[ib32].qs; const int * q8 = (const int *)bq8_1[ib32].qs;
int sumi = 0; int sumi = 0;
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]); const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8)); int grid0 = grid[0] & 0x0f0f0f0f;
const int grid_l = dpct::vectorized_binary<sycl::uchar4>( int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
grid[0] ^ signs[0], signs[0], std::minus<>()); sumi = dpct::dp4a(q8[2*l+1], grid1, dpct::dp4a(q8[2*l+0], grid0, sumi));
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
grid[1] ^ signs[1], signs[1], std::minus<>());
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
q8 += 8;
} }
const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f; const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
return d * sumi; const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
const float d = d1q * (float)bq8_1[ib32].ds[0];
const float m = d1q * (float)bq8_1[ib32].ds[1];
return d * sumi + m * delta;
#else #else
assert(false); assert(false);
return 0.f; return 0.f;
@ -8456,7 +8450,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
template <int qk, int qi, typename block_q_t, int vdr> template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1, const sycl::nd_item<3> &item_ct1,
const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) { const uint32_t *iq1s_grid_ptr) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1); item_ct1.get_local_id(1);
@ -8484,7 +8478,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
(item_ct1.get_local_id(2) % (item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int (qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr); tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr);
} }
// sum up partial sums and write back result // sum up partial sums and write back result
@ -10227,16 +10221,12 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr(); auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr();
auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_s( dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_ptr_ct1);
vx, y, item_ct1, iq1s_grid_ptr_ct1,
ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
}); });
}); });
} }
@ -10967,11 +10957,9 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
iq1s_grid_gpu.init(*stream); iq1s_grid_gpu.init(*stream);
ksigns64.init(*stream);
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr(); auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr();
auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
@ -10979,7 +10967,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
[[intel::reqd_sub_group_size(32)]] { [[intel::reqd_sub_group_size(32)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>( mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1, vx, vy, dst, ncols, nrows, item_ct1,
iq1s_grid_ptr_ct1, ksigns64_ptr_ct1); iq1s_grid_ptr_ct1);
}); });
}); });
} }