fix quantize bug

This commit is contained in:
luoyu-intel 2024-06-28 09:40:41 +08:00
parent 4cd48c7cfc
commit 61f0cd58dc

View file

@ -358,10 +358,11 @@ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00,
} }
} }
template<int QUANT_BLOCK_TILE>
static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int ix = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
if (ix >= kx_padded) { if (ix >= kx_padded) {
return; return;
@ -376,23 +377,39 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
const int ib = i_padded / QK8_1; // block index const int ib = i_padded / QK8_1; // block index
const int iqs = i_padded % QK8_1; // quant index const int iqs = i_padded % QK8_1; // quant index
typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
float amax = sycl::fabs((float)xi); TC zeros;
float sum = xi; TQ qzeros;
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int i = 0; i < QUANT_BLOCK_TILE; i++)
amax = sycl::fmax(amax, dpct::permute_sub_group_by_xor( {
item_ct1.get_sub_group(), amax, mask)); zeros[i] = 0.f;
sum += qzeros[i] = 0;
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask);
} }
const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
float sum = xi[0];
float amax = sycl::fabs(xi[0]);
#pragma unroll
for (int i = 1; i < QUANT_BLOCK_TILE; i++)
{
sum += xi[i];
amax = sycl::fmax(sycl::fabs(xi[i]), amax);
}
sum = warp_reduce_sum(sum, item_ct1);
amax = warp_reduce_max(amax, item_ct1);
const float d = amax / 127; const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : sycl::round(xi / d); TQ q = qzeros;
if (amax != 0.0f)
{
#pragma unroll
for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
q[i] = sycl::round(xi[i] / d);
}
}
y[ib].qs[iqs] = q; *(TQ *)&y[ib].qs[iqs] = q;
if (iqs > 0) { if (iqs > 0) {
return; return;
@ -1487,7 +1504,9 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
queue_ptr stream) { queue_ptr stream) {
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
const sycl::range<3> num_blocks(1, ky, block_num_x); const sycl::range<3> num_blocks(1, ky, block_num_x);
const sycl::range<3> block_size(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE); int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
static_assert(QK8_1 % WARP_SIZE == 0);
const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
@ -1495,7 +1514,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
stream->parallel_for( stream->parallel_for(
sycl::nd_range<3>(num_blocks * block_size, block_size), sycl::nd_range<3>(num_blocks * block_size, block_size),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
quantize_q8_1(x, vy, kx, kx_padded, item_ct1); quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
}); });
} }
} }
@ -3867,7 +3886,6 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
int64_t min_compute_capability = INT_MAX; int64_t min_compute_capability = INT_MAX;
if (split) { if (split) {