iq3_xs: tiny Metal speed improvement
This commit is contained in:
parent
87038fe198
commit
4d5feebeb6
1 changed files with 6 additions and 6 deletions
|
@ -4515,8 +4515,8 @@ void kernel_mul_mv_iq3_xs_f32_impl(
|
||||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
||||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
||||||
for (int j = 0; j < 4; ++j) {
|
for (int j = 0; j < 4; ++j) {
|
||||||
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
||||||
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sumf[row] += d * (sum[0] + sum[1]);
|
sumf[row] += d * (sum[0] + sum[1]);
|
||||||
|
@ -5179,14 +5179,14 @@ void dequantize_iq3_xs(device const block_iq3_xs * xb, short il, thread type4x4
|
||||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
||||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
reg[0][i] = dl * grid1[i] * (signs[0] & kmask_iq2xs[i+0] ? -1 : 1);
|
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
||||||
reg[1][i] = dl * grid2[i] * (signs[0] & kmask_iq2xs[i+4] ? -1 : 1);
|
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
||||||
}
|
}
|
||||||
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
||||||
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
reg[2][i] = dl * grid1[i] * (signs[1] & kmask_iq2xs[i+0] ? -1 : 1);
|
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
||||||
reg[3][i] = dl * grid2[i] * (signs[1] & kmask_iq2xs[i+4] ? -1 : 1);
|
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue