k_quants: change Q5_K to be type 0 when QK_K = 64

Still needs AVX2 implementation
Iwan Kawrakow 2023-06-24 17:39:25 +03:00
parent 4f61506929
commit ccf4901334
4 changed files with 103 additions and 93 deletions
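
Note on the change: in the k-quants scheme, a "type-0" quantization reconstructs weights from a scale alone, x ≈ d·q with signed q, while "type-1" additionally stores a per-block minimum that is subtracted after scaling, x ≈ d·q − m. With QK_K = 64, Q5_K previously kept four fp16 scale/min values per 64-weight super-block; after this commit it keeps one fp16 super-block scale plus four 8-bit block scales, shrinking a super-block from 4·2 + 8 + 32 = 48 bytes to 2 + 4 + 8 + 32 = 46 bytes, i.e. from 6.0 to 5.75 bits per weight. Below is a minimal scalar sketch of the new layout and its dequantization, not part of the commit: it assumes a compiler with _Float16, and the helper name dequantize_q5_k_64_ref is made up for illustration.

#include <stdint.h>

#define QK_K 64

typedef _Float16 ggml_fp16_t;   // assumption: compiler supports _Float16

typedef struct {
    ggml_fp16_t d;              // super-block scale
    int8_t  scales[QK_K/16];    // 8-bit block scales
    uint8_t qh[QK_K/8];         // quants, high bit
    uint8_t qs[QK_K/2];         // quants, low 4 bits
} block_q5_K;                   // 2 + 4 + 8 + 32 = 46 bytes for 64 weights

// y[j] = d * scales[j/16] * (q5 - 16), with q5 in 0..31 assembled from
// the low nibble in qs and the high bit in qh.
static void dequantize_q5_k_64_ref(const block_q5_K * x, float * y) {
    const float d = (float)x->d;
    for (int j = 0; j < QK_K; ++j) {
        const int low  = j < 32 ? (x->qs[j] & 0xF) : (x->qs[j-32] >> 4);
        const int high = (x->qh[j%8] >> (j/8)) & 1;  // bit j/8 of byte j%8
        y[j] = d * x->scales[j/16] * ((low | (high << 4)) - 16);
    }
}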


@@ -164,11 +164,12 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2,
 #ifdef GGML_QKK_64
 typedef struct {
-    half d[2*QK_K/32];          // super-block scales/mins
+    half    d;                  // super-block scale
+    int8_t  scales[QK_K/16];    // block scales
     uint8_t qh[QK_K/8];         // quants, high bit
     uint8_t qs[QK_K/2];         // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
 #else
 typedef struct {
     half d;                     // super-block scale for quantized scales
@@ -548,10 +549,12 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     const uint8_t q = x[i].qs[tid];
     const int im = tid/8;  // 0...3
     const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
     float * y = yy + i*QK_K + tid;
-    y[ 0] = (float)x[i].d[0] * ((q & 0xF) + ((h >> 0) & 1 ? 16 : 0)) - (float)x[i].d[1];
-    y[32] = (float)x[i].d[2] * ((q >>  4) + ((h >> 4) & 1 ? 16 : 0)) - (float)x[i].d[3];
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
 #endif
 }
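
The rewritten expressions are the same 5-bit quant re-centered at zero: the old path computed d_b·(q + 16·h) − m_b with a separate scale/min pair, while the new path computes d·sc·((q + 16·h) − 16). A tiny self-check of the identity behind `(q & 0xF) - (h ? 0 : 16)` (illustrative, not part of the commit):

#include <assert.h>

// q = low nibble (0..15), h = high bit (0 or 1)
static int new_form(int q, int h) { return q - (h ? 0 : 16); }
static int full_q5 (int q, int h) { return (q | (h << 4)) - 16; }

int main(void) {
    for (int q = 0; q < 16; ++q)
        for (int h = 0; h < 2; ++h)
            assert(new_form(q, h) == full_q5(q, h));
    return 0;
}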
@@ -992,17 +995,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
         const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
         const float   * y = yy + i*QK_K + step;
-        const half2   * d = (const half2 *)x[i].d;
-        float2 df1 = __half22float2(d[0]);
-        float2 df2 = __half22float2(d[1]);
+        const float     d = x[i].d;
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
             const uint8_t h = x[i].qh[in+j] >> im;
-            sum += y[j+ 0] * (df1.x * ((q[j+ 0] & 0xF) + (((h >> 0) & 1) << 4)) - df1.y)
-                 + y[j+16] * (df1.x * ((q[j+16] & 0xF) + (((h >> 2) & 1) << 4)) - df1.y)
-                 + y[j+32] * (df2.x * ((q[j+ 0] >>  4) + (((h >> 4) & 1) << 4)) - df2.y)
-                 + y[j+48] * (df2.x * ((q[j+16] >>  4) + (((h >> 6) & 1) << 4)) - df2.y);
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
         }
         tmp += sum;
     }


@@ -823,7 +823,8 @@ typedef struct {
 #if QK_K == 64
 typedef struct {
-    half4   d;                  // super-block scales/mins
+    half    d;                  // super-block scales/mins
+    int8_t  scales[QK_K/16];    // 8-bit block scales
     uint8_t qh[QK_K/8];         // quants, high bit
     uint8_t qs[QK_K/2];         // quants, low 4 bits
 } block_q5_K;
@@ -1062,20 +1063,21 @@ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, i
 #else
     for (int i = 0; i < nb; i++) {
-        const float4 d = (float4)x[i].d;
+        const float d = (float)x[i].d;
         device const uint8_t * ql = x[i].qs;
         device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
         for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d[0] * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - d[1];
-            y[l+ 8] = d[0] * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - d[1];
-            y[l+16] = d[0] * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - d[1];
-            y[l+24] = d[0] * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - d[1];
-            y[l+32] = d[2] * ((ql[l+ 0] >>  4) + (qh[l] & 0x10 ? 16 : 0)) - d[3];
-            y[l+40] = d[2] * ((ql[l+ 8] >>  4) + (qh[l] & 0x20 ? 16 : 0)) - d[3];
-            y[l+48] = d[2] * ((ql[l+16] >>  4) + (qh[l] & 0x40 ? 16 : 0)) - d[3];
-            y[l+56] = d[2] * ((ql[l+24] >>  4) + (qh[l] & 0x80 ? 16 : 0)) - d[3];
+            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
         }
         y += QK_K;
     }
@@ -1336,12 +1338,6 @@ kernel void kernel_mul_mat_q3_K_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
-
     const int nb = ne00/QK_K;

     const int64_t r0 = tgpig.x;
@@ -1355,6 +1351,12 @@ kernel void kernel_mul_mat_q3_K_f32(
 #if QK_K == 256
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = tpitg.y;          // expecting 16
     const int ip  = tid/8;            // 0 or 1
     const int il  = tid/2 - 4*ip;     // 0...3
@@ -1682,18 +1684,18 @@ kernel void kernel_mul_mat_q5_K_f32(
     for (int i = tpitg.y; i < nb; i += tptg.y) {

+        const float d = (float)x[i].d;
+
         device const uint8_t * q = x[i].qs + il;
         device const uint8_t * h = x[i].qh + in;
+        device const int8_t  * s = x[i].scales;
         device const float   * y = yy + i*QK_K + il;

-        const float4 d = (float4)x[i].d;
-
         for (int l = 0; l < 4; ++l) {
             const uint8_t hl = h[l] >> im;
-            sumf += y[l+ 0] * (d[0] * ((q[l+ 0] & 0xF) + (hl & 0x01 ? 16 : 0)) - d[1])
-                  + y[l+16] * (d[0] * ((q[l+16] & 0xF) + (hl & 0x04 ? 16 : 0)) - d[1])
-                  + y[l+32] * (d[2] * ((q[l+ 0] >>  4) + (hl & 0x10 ? 16 : 0)) - d[3])
-                  + y[l+48] * (d[2] * ((q[l+16] >>  4) + (hl & 0x40 ? 16 : 0)) - d[3]);
+            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
+                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
         }
     }
 #endif


@@ -792,10 +792,13 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

-    uint8_t L[QK_K];
 #if QK_K == 256
+    uint8_t L[QK_K];
     float mins[QK_K/32];
     float scales[QK_K/32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
 #endif

     for (int i = 0; i < nb; i++) {
@@ -869,20 +872,30 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
             ql += 32;
         }
 #else
-        for (int j = 0; j < QK_K/32; ++j) {
-            float min;
-            float scale = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &min, 5);
-            y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
-            y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
         }
-        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = ggml_fp32_to_fp16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
             if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
-            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
-                l = MAX(0, MIN(31, l));
-                L[32*j + ii] = l;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
             }
         }
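
In this new path each 16-weight block gets a float scale from make_qx_quants (type-0: quants in −16…15, stored with a +16 offset), and the block scales are themselves quantized to int8. The factor iscale = −128/max_scale maps the largest-magnitude block scale to −128, the full end of the signed 8-bit range, and the stored super-block scale is d = 1/iscale. A worked round-trip with assumed numbers, using roundf in place of k_quants' nearest_int:

#include <math.h>
#include <stdio.h>

int main(void) {
    float max_scale = 0.05f;                      // largest-magnitude block scale
    float iscale    = -128.f/max_scale;           // -2560: maps 0.05 -> -128
    float d         = 1.f/iscale;                 // stored super-block scale
    int   l         = (int)roundf(iscale*0.03f);  // another block scale: 0.03 -> -77
    printf("restored = %g\n", d*l);               // ~0.0300781, about 0.26% error
    return 0;
}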
@@ -938,17 +951,17 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
             u1 <<= 2; u2 <<= 2;
         }
 #else
-        float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
-        float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
+        float d = ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
         for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d1 * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - m1;
-            y[l+ 8] = d1 * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - m1;
-            y[l+16] = d1 * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - m1;
-            y[l+24] = d1 * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - m1;
-            y[l+32] = d2 * ((ql[l+ 0] >>  4) + (qh[l] & 0x10 ? 16 : 0)) - m2;
-            y[l+40] = d2 * ((ql[l+ 8] >>  4) + (qh[l] & 0x20 ? 16 : 0)) - m2;
-            y[l+48] = d2 * ((ql[l+16] >>  4) + (qh[l] & 0x40 ? 16 : 0)) - m2;
-            y[l+56] = d2 * ((ql[l+24] >>  4) + (qh[l] & 0x80 ? 16 : 0)) - m2;
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
         }
         y += QK_K;
 #endif
@@ -2751,19 +2764,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     float sumf = 0;

-    float32x4_t acc1 = vdupq_n_f32(0.f);
-    float32x4_t acc2 = vdupq_n_f32(0.f);
-
     for (int i = 0; i < nb; ++i) {

-        const float16x4_t s16 = vld1_f16(x[i].d);
-        const float32x4_t s32 = vmulq_n_f32(vcvt_f32_f16(s16), y[i].d);
-        //const int16x4_t bi16 = vld1_s16(y[i].bsums);
-        //const int32x4_t bi32 = vmovl_s16(vpadd_s16(bi16, bi16));
-        //const float32x4_t bf32 = vcvtq_f32_s32(bi32);
-        //sumf -= (vgetq_lane_f32(s32, 1) * vgetq_lane_f32(bf32, 0) + vgetq_lane_f32(s32, 3) * vgetq_lane_f32(bf32, 1));
-        // The above is slightly slower than just this:
-        sumf -= (vgetq_lane_f32(s32, 1) * (y[i].bsums[0] + y[i].bsums[1]) + vgetq_lane_f32(s32, 3) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        sumf -= 16.f * d * (sc[0] * y[i].bsums[0] + sc[1] * y[i].bsums[1] + sc[2] * y[i].bsums[2] + sc[3] * y[i].bsums[3]);

         const uint8_t * restrict q5 = x[i].qs;
         const uint8_t * restrict qh = x[i].qh;
@@ -2787,34 +2793,35 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 #if defined(__ARM_FEATURE_DOTPROD)
-        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])),
-                vgetq_lane_f32(s32, 0));
-        acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])),
-                vgetq_lane_f32(s32, 2));
+        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
 #else
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                                        vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
         const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                        vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        const int16x8_t p01_16 = vaddq_s16(p0, p1);
-        const int32x4_t p01_32 = vaddq_s32(vmovl_s16(vget_low_s16(p01_16)), vmovl_s16(vget_high_s16(p01_16)));
-        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(p01_32), vgetq_lane_f32(s32, 0));
+        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);

         const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
                                        vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        const int16x8_t p02_16 = vaddq_s16(p2, p3);
-        const int32x4_t p02_32 = vaddq_s32(vmovl_s16(vget_low_s16(p02_16)), vmovl_s16(vget_high_s16(p02_16)));
-        acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(p02_32), vgetq_lane_f32(s32, 2));
+        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+        sumf += d*sumi;
 #endif
     }

-    *s = vaddvq_f32(vaddq_f32(acc1, acc2)) + sumf;
+    *s = sumf;

-#elif defined __AVX2__
+#elif defined z__AVX2__

     const __m256i m4 = _mm256_set1_epi8(0xF);
     const __m256i m1 = _mm256_set1_epi16(1);
@@ -2884,19 +2891,17 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         }
         for (int is = 0; is < 8; ++is) {
             uint8_t m = 1 << is;
-            for (int l = 0; l < 8; ++l) a[8*is + l] += (hm[l] & m ? 16 : 0);
+            for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
         }

-        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict sc = x[i].scales;

-        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float dl = d * sc[j];
             for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l <  8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
             q8 += 16; a += 16;
-            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
-            q8 += 16; a += 16;
-            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[8+l]);
         }
     }
     for (int l = 0; l < 8; ++l) sumf += sums[l];
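
Because every quant now carries an implicit −16 offset, the dot product against Q8_K splits into an integer part and a correction: Σ d·sc[j/16]·(q5 − 16)·q8 = d·Σ sc[b]·(Σ q5·q8) − 16·d·Σ sc[b]·bsums[b], where bsums[b] is the precomputed sum of 16 q8 quants. The NEON path uses the correction form (the `sumf -= 16.f * d * (...)` line), while the generic path above folds the −16 directly into `a[]`. Renaming `__AVX2__` to `z__AVX2__` presumably routes AVX2 builds onto the generic path until the promised AVX2 implementation lands. A scalar sketch of the decomposition, with an illustrative function name, not part of the commit:

#include <stdint.h>

// q5[j] in 0..31 (high bit folded in), q8 = Q8_K quants, d = d_q5 * d_q8.
static float dot_q5_k_64_ref(float d, const int8_t sc[4],
                             const uint8_t q5[64], const int8_t q8[64]) {
    int sumi[4] = {0}, bsums[4] = {0};
    for (int j = 0; j < 64; ++j) {
        sumi [j/16] += q5[j] * q8[j];   // integer dot, offset not yet applied
        bsums[j/16] += q8[j];           // per-16 sums, precomputed in Q8_K
    }
    float sumf = 0.f;
    for (int b = 0; b < 4; ++b) sumf += d * (sc[b]*sumi[b] - 16*sc[b]*bsums[b]);
    return sumf;                        // == sum_j d*sc[j/16]*(q5[j]-16)*q8[j]
}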


@@ -80,11 +80,12 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/
 // Effectively 5.5 bits per weight
 #ifdef GGML_QKK_64
 typedef struct {
-    ggml_fp16_t d[2*QK_K/32];     // super-block scales/mins
+    ggml_fp16_t d;                // super-block scale
+    int8_t  scales[QK_K/16];      // 8-bit block scales
     uint8_t qh[QK_K/8];           // quants, high bit
     uint8_t qs[QK_K/2];           // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
 #else
 typedef struct {
     ggml_fp16_t d;                // super-block scale for quantized scales