iq3s_mult: ARM and Metal

This commit is contained in:
Iwan Kawrakow 2024-03-03 11:30:01 +02:00
parent b6402fa757
commit 5b9c8785fa
2 changed files with 46 additions and 14 deletions

View file

@ -2546,7 +2546,12 @@ typedef struct {
uint8_t signs[QK_K/8];
uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;
#ifdef IQ3S_SLOW_MULT
#define IQ3S_MULTIPLIER 190842953
#else
//#define IQ3S_MULTIPLIER 898886
#define IQ3S_MULTIPLIER 842866
#endif
typedef struct {
half d;
@ -4691,15 +4696,21 @@ void kernel_mul_mv_iq3_s_f32_impl(
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
{
uint32_t aux32;
thread int8_t * q = (thread int8_t *)&aux32;
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
#ifdef IQ3S_SLOW_MULT
uint32_t aux32;
thread int8_t * q = (thread int8_t *)&aux32;
for (int i = 0; i < nval; ++i) {
aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1;
values[pos + i] = aux32;
}
#else
for (int i = 0; i < nval; ++i) {
values[pos + i] = ((IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f) | 0x01010101;
}
#endif
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@ -4733,17 +4744,16 @@ void kernel_mul_mv_iq3_s_f32_impl(
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
//aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
//aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
//threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
//threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + qs[2*l+0] +
select(0, 256, qh[0] & kmask_iq2xs[2*l+0]));
threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + qs[2*l+1] +
select(0, 256, qh[0] & kmask_iq2xs[2*l+1]));
// This is slower than pre-computing the grid in shared memory and loading from there
//aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
//aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
//for (int j = 0; j < 4; ++j) {
// sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
// sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
//}
threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
//sum[0] += yl[8*l + j + 0] * (2*((grid[j+0] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
//sum[1] += yl[8*l + j + 4] * (2*((grid[j+4] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
}
@ -5657,6 +5667,7 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
uint32_t aux32[2];
thread const int8_t * grid = (thread const int8_t *)aux32;
#ifdef IQ3S_SLOW)MULT
aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
for (int i = 0; i < 4; ++i) {
@ -5669,6 +5680,20 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
reg[2][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
#else
aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f) | 0x01010101;
aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f) | 0x01010101;
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid[i+0] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid[i+4] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f) | 0x01010101;
aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f) | 0x01010101;
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid[i+0] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid[i+4] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
#endif
}
template <typename type4x4>