metal : fix accuracy of dequantization kernels
ggml-ci
This commit is contained in:
parent
bc014485be
commit
694449f0e8
2 changed files with 15 additions and 14 deletions
|
@ -3521,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|||
|
||||
template <typename type4x4>
|
||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||
const half d = xb->d;
|
||||
const half min = xb->dmin;
|
||||
const float d = xb->d;
|
||||
const float min = xb->dmin;
|
||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||
half dl, ml;
|
||||
float dl, ml;
|
||||
uint8_t sc = xb->scales[il];
|
||||
|
||||
#if QK_K == 256
|
||||
|
@ -3594,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|||
q = q + (il/4) * 32 + 16 * (il&1);
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const half min = xb->dmin;
|
||||
const half dl = d * sc[0];
|
||||
const half ml = min * sc[1];
|
||||
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const float min = xb->dmin;
|
||||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
#else
|
||||
q = q + 16 * (il&1);
|
||||
device const uint8_t * s = xb->scales;
|
||||
|
@ -3624,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|||
uint8_t ul = 1 << (il/2);
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const half min = xb->dmin;
|
||||
const half dl = d * sc[0];
|
||||
const half ml = min * sc[1];
|
||||
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const float min = xb->dmin;
|
||||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
|
||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||
const half qh_val = il<2 ? 16.h : 256.h;
|
||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||
const float qh_val = il<2 ? 16.f : 256.f;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
||||
}
|
||||
|
|
|
@ -432,9 +432,10 @@ struct test_case {
|
|||
if (err > ud->max_err) {
|
||||
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
|
||||
//for (int i = 0; i < f1.size(); i++) {
|
||||
// printf("(%f, %f) ", f1[i], f2[i]);
|
||||
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
||||
//}
|
||||
//printf("\n");
|
||||
//exit(1);
|
||||
ud->ok = false;
|
||||
}
|
||||
return true;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue