iq2_xxs: Metal dot product now works

Current performance:
PP-512 = 475 t/s
TG-128 = 47.3 t/s

Not the greatest performance, but not complete garbage either.
This commit is contained in:
Iwan Kawrakow 2024-01-03 18:59:14 +01:00
parent d383f00eea
commit dd29610153

View file

@@ -3601,10 +3601,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device const float * y4 = y + 32 * ix;
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
thread uint16_t * aux16 = (thread uint16_t *)aux32;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
@@ -3620,13 +3616,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
const float db = xr->d;
device const uint16_t * q2 = xr->qs + 4 * ib;
for (int i = 0; i < 4; ++i) aux16[i] = q2[i];
const float d = db * (0.5f + (aux32[1] >> 28));
device const uint8_t * aux8 = (device const uint8_t *)q2;
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = db * (0.5f + (aux32 >> 28));
float sum = 0;
for (int l = 0; l < 4; ++l) {
constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[l]);
const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}