iq1s_blocks16: very slightly faster TG on Metal
Still pathetic at 37 t/s
This commit is contained in:
parent
15acc7923b
commit
8561139a48
1 changed files with 12 additions and 11 deletions
|
@ -4338,19 +4338,18 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[16];
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
const int ix = tiisg/2;
|
||||
const int il = tiisg%2;
|
||||
const int ix = tiisg;
|
||||
|
||||
device const float * y4 = y + 32 * ix + 16 * il;
|
||||
device const float * y4 = y + 32 * ix;
|
||||
|
||||
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
||||
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
||||
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
yl[i] = y4[i];
|
||||
}
|
||||
|
||||
|
@ -4358,18 +4357,20 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
const int ib = ib32 % (QK_K / 32);
|
||||
|
||||
device const block_iq1_s * xr = x + ibl;
|
||||
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
||||
device const uint8_t * qs = xr->qs + 4 * ib;
|
||||
device const uint16_t * qh = xr->qh + ib;
|
||||
device const half * dh = &xr->d;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | (((qh[0] >> (6*il+0)) & 7) << 8)));
|
||||
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | (((qh[0] >> (6*il+3)) & 7) << 8)));
|
||||
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)));
|
||||
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 5) & 0x700)));
|
||||
constant int8_t * grid3 = (constant int8_t *)(iq1s_grid + (qs[2] | ((qh[0] << 2) & 0x700)));
|
||||
constant int8_t * grid4 = (constant int8_t *)(iq1s_grid + (qs[3] | ((qh[0] >> 1) & 0x700)));
|
||||
|
||||
float sum = 0;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sum += yl[j+0] * grid1[j] + yl[j+8] * grid2[j];
|
||||
sum += yl[j+0] * grid1[j] + yl[j+8] * grid2[j] + yl[j+16] * grid3[j] + yl[j+24] * grid4[j];
|
||||
}
|
||||
sumf[row] += (float)dh[0] * sum * (2*(qh[0] >> 12) + 1);
|
||||
|
||||
|
@ -4378,7 +4379,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
qh += nb*sizeof(block_iq1_s)/2;
|
||||
}
|
||||
|
||||
y4 += 16 * 32;
|
||||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue