iq3_xs: tiny Metal speed improvement
This commit is contained in:
parent
1777825550
commit
87038fe198
1 changed files with 12 additions and 12 deletions
|
@ -4500,7 +4500,7 @@ void kernel_mul_mv_iq3_xs_f32_impl(
|
|||
|
||||
device const block_iq3_xs * xr = x + ibl;
|
||||
device const uint8_t * qs = xr->qs + 8 * ib;
|
||||
device const uint8_t * qh = xr->qh;
|
||||
device const uint8_t * qh = xr->qh + ib;
|
||||
device const uint8_t * sc = xr->scales + (ib/2);
|
||||
device const uint8_t * signs = xr->signs + 4 * ib;
|
||||
device const half * dh = &xr->d;
|
||||
|
@ -4512,8 +4512,8 @@ void kernel_mul_mv_iq3_xs_f32_impl(
|
|||
|
||||
float2 sum = {0};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[ib] << (8-2*l)) & 256)));
|
||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[ib] << (7-2*l)) & 256)));
|
||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
||||
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
||||
|
@ -5173,20 +5173,20 @@ void dequantize_iq3_xs(device const block_iq3_xs * xb, short il, thread type4x4
|
|||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint8_t * qs = xb->qs + 8*ib32;
|
||||
device const uint8_t * qh = xb->qh;
|
||||
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
||||
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
||||
const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh[ib32] << (8-4*il)) & 256)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh[ib32] << (7-4*il)) & 256)));
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * grid1[i] * (signs[0] & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||
reg[1][i] = dl * grid2[i] * (signs[0] & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||
reg[0][i] = dl * grid1[i] * (signs[0] & kmask_iq2xs[i+0] ? -1 : 1);
|
||||
reg[1][i] = dl * grid2[i] * (signs[0] & kmask_iq2xs[i+4] ? -1 : 1);
|
||||
}
|
||||
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh[ib32] << (6-4*il)) & 256)));
|
||||
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh[ib32] << (5-4*il)) & 256)));
|
||||
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
||||
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[2][i] = dl * grid1[i] * (signs[1] & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||
reg[3][i] = dl * grid2[i] * (signs[1] & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||
reg[2][i] = dl * grid1[i] * (signs[1] & kmask_iq2xs[i+0] ? -1 : 1);
|
||||
reg[3][i] = dl * grid2[i] * (signs[1] & kmask_iq2xs[i+4] ? -1 : 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue