iq1s_blocks16: speedup Metal by packing codebook into uint32_t's
This commit is contained in:
parent
8561139a48
commit
d3da9d1617
1 changed files with 19 additions and 12 deletions
|
@ -4349,8 +4349,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
|
||||
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
||||
|
||||
float sumy = 0;
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
yl[i] = y4[i];
|
||||
sumy += yl[i];
|
||||
}
|
||||
|
||||
const int ibl = ib32 / (QK_K / 32);
|
||||
|
@ -4363,16 +4365,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)));
|
||||
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 5) & 0x700)));
|
||||
constant int8_t * grid3 = (constant int8_t *)(iq1s_grid + (qs[2] | ((qh[0] << 2) & 0x700)));
|
||||
constant int8_t * grid4 = (constant int8_t *)(iq1s_grid + (qs[3] | ((qh[0] >> 1) & 0x700)));
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 5) & 0x700)));
|
||||
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid + (qs[2] | ((qh[0] << 2) & 0x700)));
|
||||
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid + (qs[3] | ((qh[0] >> 1) & 0x700)));
|
||||
|
||||
float sum = 0;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sum += yl[j+0] * grid1[j] + yl[j+8] * grid2[j] + yl[j+16] * grid3[j] + yl[j+24] * grid4[j];
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
||||
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
|
||||
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
||||
+ yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
|
||||
}
|
||||
sumf[row] += (float)dh[0] * sum * (2*(qh[0] >> 12) + 1);
|
||||
sumf[row] += (float)dh[0] * (sum - sumy) * (2*(qh[0] >> 12) + 1);
|
||||
|
||||
dh += nb*sizeof(block_iq1_s)/2;
|
||||
qs += nb*sizeof(block_iq1_s);
|
||||
|
@ -5072,11 +5077,13 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint16_t * qh = xb->qh;
|
||||
const float dl = d * (2*(qh[ib32] >> 12) + 1);
|
||||
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8)));
|
||||
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8)));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[i/4+0][i%4] = dl * grid1[i];
|
||||
reg[i/4+2][i%4] = dl * grid2[i];
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * (grid1[i] & 0xf) - dl;
|
||||
reg[1][i] = dl * (grid1[i] >> 4) - dl;
|
||||
reg[2][i] = dl * (grid2[i] & 0xf) - dl;
|
||||
reg[3][i] = dl * (grid2[i] >> 4) - dl;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue