iq1_m: use common definition of iq1m_scale_t

This commit is contained in:
Iwan Kawrakow 2024-03-26 06:16:02 +02:00
parent b68f32b391
commit 9a5786e939
5 changed files with 13 additions and 30 deletions

View file

@ -385,6 +385,12 @@ typedef struct {
} block_iq1_m;
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
// Used by IQ1_M quants
typedef union {
ggml_half f16;
uint16_t u16;
} iq1m_scale_t;
// Non-linear quants
#define QK4_NL 32
typedef struct {

View file

@ -501,11 +501,6 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
}
typedef union {
half f16;
uint16_t u16;
} iq1m_scale_t;
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

View file

@ -1197,10 +1197,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
q8 += 8;
}
#endif
typedef union {
half f16;
uint16_t u16;
} iq1m_scale_t;
iq1m_scale_t scale;
const uint16_t * sc = (const uint16_t *)bq1->scales;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

View file

@ -4456,11 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
}
}
typedef union {
half f16;
uint16_t u16;
} iq1m_scale_t;
void kernel_mul_mv_iq1_m_f32_impl(
device const void * src0,
device const float * src1,

View file

@ -3474,11 +3474,6 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
}
}
typedef union {
ggml_fp16_t fp16;
uint16_t u16;
} iq1m_scale_t;
void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@ -3492,7 +3487,7 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
const uint16_t * sc = (const uint16_t *)x[i].scales;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = GGML_FP16_TO_FP32(scale.fp16);
const float d = GGML_FP16_TO_FP32(scale.f16);
const uint8_t * qs = x[i].qs;
const uint8_t * qh = x[i].qh;
@ -9761,6 +9756,8 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const int nb = n / QK_K;
iq1m_scale_t scale;
#if defined __ARM_NEON
const int32x4_t mask = vdupq_n_s32(0x7);
@ -9776,8 +9773,6 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
ggml_int8x16x4_t q1b;
ggml_int8x16x4_t q8b;
iq1m_scale_t scale;
uint32_t aux32;
const uint8_t * aux8 = (const uint8_t *)&aux32;
@ -9828,7 +9823,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
}
sumf += y[i].d * GGML_FP16_TO_FP32(scale.fp16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
}
*s = sumf;
@ -9838,8 +9833,6 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const __m256i mask = _mm256_set1_epi16(0x7);
const __m256i mone = _mm256_set1_epi16(1);
iq1m_scale_t scale;
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
@ -9894,7 +9887,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
qs += 8; qh += 4;
}
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.fp16));
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
@ -9904,8 +9897,6 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
#else
iq1m_scale_t scale;
int sum1[2], sum2[2], delta[4];
float sumf = 0;
@ -9944,7 +9935,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
qh += 2;
}
sumf += GGML_FP16_TO_FP32(scale.fp16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
}
*s = sumf;
@ -12198,7 +12189,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
}
}
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
s.fp16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
sc[0] |= ((s.u16 & 0x000f) << 12);
sc[1] |= ((s.u16 & 0x00f0) << 8);
sc[2] |= ((s.u16 & 0x0f00) << 4);