iq1s_blocks16: speedup Metal by packing codebook into uint32_t's

This commit is contained in:
Iwan Kawrakow 2024-03-09 07:05:20 +01:00
parent 8561139a48
commit d3da9d1617

View file

@ -4349,8 +4349,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
for (int ib32 = ix; ib32 < nb32; ib32 += 32) { for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
float sumy = 0;
for (int i = 0; i < 32; ++i) { for (int i = 0; i < 32; ++i) {
yl[i] = y4[i]; yl[i] = y4[i];
sumy += yl[i];
} }
const int ibl = ib32 / (QK_K / 32); const int ibl = ib32 / (QK_K / 32);
@ -4363,16 +4365,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
for (int row = 0; row < N_DST; row++) { for (int row = 0; row < N_DST; row++) {
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700))); constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)));
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 5) & 0x700))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 5) & 0x700)));
constant int8_t * grid3 = (constant int8_t *)(iq1s_grid + (qs[2] | ((qh[0] << 2) & 0x700))); constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid + (qs[2] | ((qh[0] << 2) & 0x700)));
constant int8_t * grid4 = (constant int8_t *)(iq1s_grid + (qs[3] | ((qh[0] >> 1) & 0x700))); constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid + (qs[3] | ((qh[0] >> 1) & 0x700)));
float sum = 0; float sum = 0;
for (int j = 0; j < 8; ++j) { for (int j = 0; j < 4; ++j) {
sum += yl[j+0] * grid1[j] + yl[j+8] * grid2[j] + yl[j+16] * grid3[j] + yl[j+24] * grid4[j]; sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+ yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
} }
sumf[row] += (float)dh[0] * sum * (2*(qh[0] >> 12) + 1); sumf[row] += (float)dh[0] * (sum - sumy) * (2*(qh[0] >> 12) + 1);
dh += nb*sizeof(block_iq1_s)/2; dh += nb*sizeof(block_iq1_s)/2;
qs += nb*sizeof(block_iq1_s); qs += nb*sizeof(block_iq1_s);
@ -5072,11 +5077,13 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint16_t * qh = xb->qh; device const uint16_t * qh = xb->qh;
const float dl = d * (2*(qh[ib32] >> 12) + 1); const float dl = d * (2*(qh[ib32] >> 12) + 1);
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8))); constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8)));
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8)));
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 4; ++i) {
reg[i/4+0][i%4] = dl * grid1[i]; reg[0][i] = dl * (grid1[i] & 0xf) - dl;
reg[i/4+2][i%4] = dl * grid2[i]; reg[1][i] = dl * (grid1[i] >> 4) - dl;
reg[2][i] = dl * (grid2[i] & 0xf) - dl;
reg[3][i] = dl * (grid2[i] >> 4) - dl;
} }
} }