fix quantize bug
This commit is contained in:
parent
4cd48c7cfc
commit
61f0cd58dc
1 changed file with 35 additions and 17 deletions
|
@ -358,10 +358,11 @@ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<int QUANT_BLOCK_TILE>
|
||||||
static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
|
static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int ix = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||||
item_ct1.get_local_id(2);
|
item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
|
||||||
|
|
||||||
if (ix >= kx_padded) {
|
if (ix >= kx_padded) {
|
||||||
return;
|
return;
|
||||||
|
@ -376,23 +377,39 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
|
||||||
|
|
||||||
const int ib = i_padded / QK8_1; // block index
|
const int ib = i_padded / QK8_1; // block index
|
||||||
const int iqs = i_padded % QK8_1; // quant index
|
const int iqs = i_padded % QK8_1; // quant index
|
||||||
|
typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
|
||||||
const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
|
typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
|
||||||
float amax = sycl::fabs((float)xi);
|
TC zeros;
|
||||||
float sum = xi;
|
TQ qzeros;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
for (int i = 0; i < QUANT_BLOCK_TILE; i++)
|
||||||
amax = sycl::fmax(amax, dpct::permute_sub_group_by_xor(
|
{
|
||||||
item_ct1.get_sub_group(), amax, mask));
|
zeros[i] = 0.f;
|
||||||
sum +=
|
qzeros[i] = 0;
|
||||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask);
|
|
||||||
}
|
}
|
||||||
|
const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
|
||||||
|
float sum = xi[0];
|
||||||
|
float amax = sycl::fabs(xi[0]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 1; i < QUANT_BLOCK_TILE; i++)
|
||||||
|
{
|
||||||
|
sum += xi[i];
|
||||||
|
amax = sycl::fmax(sycl::fabs(xi[i]), amax);
|
||||||
|
}
|
||||||
|
sum = warp_reduce_sum(sum, item_ct1);
|
||||||
|
amax = warp_reduce_max(amax, item_ct1);
|
||||||
|
|
||||||
const float d = amax / 127;
|
const float d = amax / 127;
|
||||||
const int8_t q = amax == 0.0f ? 0 : sycl::round(xi / d);
|
TQ q = qzeros;
|
||||||
|
if (amax != 0.0f)
|
||||||
|
{
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
|
||||||
|
q[i] = sycl::round(xi[i] / d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
y[ib].qs[iqs] = q;
|
*(TQ *)&y[ib].qs[iqs] = q;
|
||||||
|
|
||||||
if (iqs > 0) {
|
if (iqs > 0) {
|
||||||
return;
|
return;
|
||||||
|
@ -1487,7 +1504,9 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
|
||||||
queue_ptr stream) {
|
queue_ptr stream) {
|
||||||
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
|
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
|
||||||
const sycl::range<3> num_blocks(1, ky, block_num_x);
|
const sycl::range<3> num_blocks(1, ky, block_num_x);
|
||||||
const sycl::range<3> block_size(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE);
|
int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
|
||||||
|
static_assert(QK8_1 % WARP_SIZE == 0);
|
||||||
|
const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -1495,7 +1514,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
|
||||||
stream->parallel_for(
|
stream->parallel_for(
|
||||||
sycl::nd_range<3>(num_blocks * block_size, block_size),
|
sycl::nd_range<3>(num_blocks * block_size, block_size),
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
quantize_q8_1(x, vy, kx, kx_padded, item_ct1);
|
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3867,7 +3886,6 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||||
|
|
||||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
|
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
|
||||||
|
|
||||||
int64_t min_compute_capability = INT_MAX;
|
int64_t min_compute_capability = INT_MAX;
|
||||||
|
|
||||||
if (split) {
|
if (split) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue