Slightly faster Q3_K and Q5_K on metal
This commit is contained in:
parent
cf9b08485c
commit
ec13de521c
1 changed files with 38 additions and 28 deletions
|
@ -1140,7 +1140,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
float yl[16];
|
float yl[16];
|
||||||
|
|
||||||
const uint16_t kmask1 = 0x0303;
|
const uint16_t kmask1 = 0x3030;
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
const uint16_t kmask2 = 0x0f0f;
|
||||||
|
|
||||||
const int tid = tiisg/2;
|
const int tid = tiisg/2;
|
||||||
|
@ -1155,10 +1155,9 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
const uint16_t m2 = m1 << 8;
|
const uint16_t m2 = m1 << 8;
|
||||||
|
|
||||||
const int shift = 2*il;
|
const int shift = 2*il;
|
||||||
const uint16_t qm1 = 0x0003 << shift;
|
const int32_t qm1 = 0x0003 << shift;
|
||||||
const uint16_t qm2 = 0x0300 << shift;
|
const int32_t qm2 = 0x0300 << shift;
|
||||||
const int32_t v1 = 4 << shift;
|
const float v1 = 4 << shift;
|
||||||
const int32_t v2 = 1024 << shift;
|
|
||||||
|
|
||||||
const uint16_t s_shift1 = 4*ip;
|
const uint16_t s_shift1 = 4*ip;
|
||||||
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
||||||
|
@ -1171,7 +1170,10 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
device const float * y1 = yy + ix*QK_K + y_offset;
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||||
|
|
||||||
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
uint16_t scales16;
|
||||||
|
thread const int8_t * scales = (thread const int8_t *)&scales16;
|
||||||
|
|
||||||
|
float sumf1[2] = {0.f};
|
||||||
for (int i = ix; i < nb; i += 2) {
|
for (int i = ix; i < nb; i += 2) {
|
||||||
|
|
||||||
for (int l = 0; l < 8; ++l) {
|
for (int l = 0; l < 8; ++l) {
|
||||||
|
@ -1187,27 +1189,27 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
|
||||||
const float d_all = (float)dh[0];
|
const float d_all = (float)dh[0];
|
||||||
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
scales16 = ((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) << 4) & kmask1);
|
||||||
|
|
||||||
float s1 = 0, s2 = 0;
|
float s1 = 0, s2 = 0, s3 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const uint16_t qs = q[l/2];
|
const int32_t qs = q[l/2];
|
||||||
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
s1 += yl[l+0] * (qs & qm1);
|
||||||
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
s2 += yl[l+1] * (qs & qm2);
|
||||||
|
s3 += ((h[l/2] & m1) ? 0.f : yl[l+0]) + ((h[l/2] & m2) ? 0.f : yl[l+1]);
|
||||||
}
|
}
|
||||||
float d = d_all * (s1 + 1.f/256.f * s2);
|
float d = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
sumf1[row] += d * scales[0];
|
sumf1[row] += d * (scales[0] - 32);
|
||||||
sumf2[row] += d;
|
|
||||||
|
|
||||||
s1 = s2 = 0;
|
s1 = s2 = s3 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const uint16_t qs = q[l/2+8];
|
const int32_t qs = q[l/2+8];
|
||||||
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
s1 += yl[l+8] * (qs & qm1);
|
||||||
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
s2 += yl[l+9] * (qs & qm2);
|
||||||
|
s3 += ((h[l/2+8] & m1) ? 0.f : yl[l+8]) + ((h[l/2+8] & m2) ? 0.f : yl[l+9]);
|
||||||
}
|
}
|
||||||
d = d_all * (s1 + 1.f/256.f * s2);
|
d = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
sumf1[row] += d * scales[1];
|
sumf1[row] += d * (scales[1] - 32);
|
||||||
sumf2[row] += d;
|
|
||||||
|
|
||||||
q += step;
|
q += step;
|
||||||
h += step;
|
h += step;
|
||||||
|
@ -1221,7 +1223,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
const float sumf = sumf1[row] / (1 << shift);
|
||||||
const float tot = simd_sum(sumf);
|
const float tot = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
||||||
|
@ -1579,17 +1581,25 @@ kernel void kernel_mul_mat_q5_K_f32(
|
||||||
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
||||||
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
||||||
|
|
||||||
float4 acc = {0.f, 0.f, 0.f, 0.f};
|
float4 acc1 = {0.f};
|
||||||
|
float4 acc2 = {0.f};
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
uint8_t h = qh[l];
|
uint8_t h = qh[l];
|
||||||
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
|
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
||||||
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
|
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
||||||
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
|
acc1[2] += yh[l+0] * (q2[l] & 0x0F);
|
||||||
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
|
acc1[3] += yh[l+8] * (q2[l] & 0xF0);
|
||||||
|
acc2[0] += h & hm1 ? yl[l+0] : 0.f;
|
||||||
|
acc2[1] += h & hm2 ? yl[l+8] : 0.f;
|
||||||
|
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
|
||||||
|
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
|
||||||
}
|
}
|
||||||
const float dall = dh[0];
|
const float dall = dh[0];
|
||||||
const float dmin = dh[1];
|
const float dmin = dh[1];
|
||||||
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
|
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
||||||
|
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
||||||
|
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
||||||
|
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
||||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||||
|
|
||||||
q1 += step;
|
q1 += step;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue