add iq4 non linear placeholder

This commit is contained in:
abhilash1910 2024-03-28 01:09:28 -07:00
parent e190f1fca6
commit 9489fc379e

View file

@ -4581,6 +4581,24 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
} }
template<typename dst_t>
static void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);;
const int tid = item_ct1.get_local_id(2);
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = (float)x[ib].d;
for (int j = 0; j < 4; ++j) {
y[j+0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}
/* /*
DPCT1110:4: The total declared local variable size in device function DPCT1110:4: The total declared local variable size in device function
dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@ -7497,6 +7515,45 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
#endif #endif
} }
static __dpct_inline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
int & val1, int & val2) {
uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
aux32 = q4 & 0x0f0f0f0f;
uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
val1 = v1 | (v2 << 16);
aux32 = (q4 >> 4) & 0x0f0f0f0f;
v1 = values[q8[0]] | (values[q8[1]] << 8);
v2 = values[q8[2]] | (values[q8[3]] << 8);
val2 = v1 | (v2 << 16);
}
static __dpct_inline__ float
vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
int v1, v2;
int sumi1 = 0, sumi2 = 0;
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
get_int_from_table_16(aux, values, v1, v2);
sumi1 = dpct::dp4a(v1, q8[l+0], sumi1);
sumi2 = dpct::dp4a(v2, q8[l+4], sumi2);
}
const float d = (float)bq->d * __low2float(bq8_1->ds);
return d * (sumi1 + sumi2);
}
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr, int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
vec_dot_q_mul_mat_sycl_t vec_dot> vec_dot_q_mul_mat_sycl_t vec_dot>
@ -8353,6 +8410,52 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
} }
} }
template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq4_nl_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
if (row >= nrows) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
// partial sum for each thread
float tmp = 0.0f;
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
const int iqs =
vdr *
(item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
if (item_ct1.get_local_id(2) == 0) {
dst[row] = tmp;
}
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel> template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
@ -10096,6 +10199,26 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
} }
} }
template <typename dst_t>
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
dpct::queue_ptr stream) {
const int nb = (k + QK_K - 1) / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_nl(
vx, y, item_ct1);
});
});
}
}
template <typename src_t, typename dst_t> template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx, static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int k, dst_t *__restrict__ y, const int k,
@ -10150,6 +10273,8 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
return dequantize_row_iq3_s_sycl; return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl; return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F32: case GGML_TYPE_F32:
return convert_unary_sycl<float>; return convert_unary_sycl<float>;
default: default:
@ -10194,6 +10319,8 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
return dequantize_row_iq3_s_sycl; return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl; return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F16: case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>; return convert_unary_sycl<sycl::half>;
default: default:
@ -13612,6 +13739,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= VER_GEN9 ? 128 : 64; return max_compute_capability >= VER_GEN9 ? 128 : 64;
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S: