Another Q5_K speedup

This commit is contained in:
Iwan Kawrakow 2023-07-20 16:33:15 +03:00
parent 463f420710
commit 5f2e4bd8ba

View file

@ -1656,22 +1656,23 @@ kernel void kernel_mul_mat_q5_K_f32(
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb; device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
device const float * yy = (device const float *) src1 + r1*ne10; device const float * yy = (device const float *) src1 + r1*ne10;
float yl[8], yh[8];
float sumf[2]={0.f}; float sumf[2]={0.f};
const int step = sizeof(block_q5_K) * nb; const int step = sizeof(block_q5_K) * nb;
#if QK_K == 256 #if QK_K == 256
#
float yl[16], yh[16];
const uint16_t kmask1 = 0x3f3f; const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f; const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0; const uint16_t kmask3 = 0xc0c0;
const int tid = tiisg/2; const int tid = tiisg/4;
const int ix = tiisg%2; const int ix = tiisg%4;
const int im = tid/8; const int im = tid/4;
const int ir = tid%8; const int ir = tid%4;
const int n = 4; const int n = 8;
const int l0 = n*ir; const int l0 = n*ir;
const int q_offset = 32*im + l0; const int q_offset = 32*im + l0;
@ -1687,7 +1688,7 @@ kernel void kernel_mul_mat_q5_K_f32(
device const float * y1 = yy + ix*QK_K + y_offset; device const float * y1 = yy + ix*QK_K + y_offset;
for (int i = ix; i < nb; i += 2) { for (int i = ix; i < nb; i += 4) {
device const uint8_t * q1 = x[i].qs + q_offset; device const uint8_t * q1 = x[i].qs + q_offset;
device const uint8_t * qh = x[i].qh + l0; device const uint8_t * qh = x[i].qh + l0;
@ -1696,11 +1697,11 @@ kernel void kernel_mul_mat_q5_K_f32(
device const float * y2 = y1 + 128; device const float * y2 = y1 + 128;
float4 sumy = {0.f, 0.f, 0.f, 0.f}; float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 8; ++l) {
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
yl[l+4] = y1[l+32]; sumy[1] += yl[l+4]; yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
yh[l+4] = y2[l+32]; sumy[3] += yh[l+4]; yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
} }
for (int row = 0; row < 2; ++row) { for (int row = 0; row < 2; ++row) {
@ -1716,9 +1717,9 @@ kernel void kernel_mul_mat_q5_K_f32(
for (int l = 0; l < n; ++l) { for (int l = 0; l < n; ++l) {
uint8_t h = qh[l]; uint8_t h = qh[l];
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
acc[1] += yl[l+4] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
acc[3] += yh[l+4] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
} }
const float dall = dh[0]; const float dall = dh[0];
const float dmin = dh[1]; const float dmin = dh[1];
@ -1732,10 +1733,12 @@ kernel void kernel_mul_mat_q5_K_f32(
} }
y1 += 2 * QK_K; y1 += 4 * QK_K;
} }
#else #else
float yl[8], yh[8];
const int il = 4 * (tiisg/8); // 0, 4, 8, 12 const int il = 4 * (tiisg/8); // 0, 4, 8, 12
const int ix = tiisg%8; const int ix = tiisg%8;
const int im = il/8; // 0, 0, 1, 1 const int im = il/8; // 0, 0, 1, 1
@ -1747,10 +1750,10 @@ kernel void kernel_mul_mat_q5_K_f32(
float4 sumy = {0.f, 0.f, 0.f, 0.f}; float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
yl[l+0] = y[l+ 0]; //sumy[0] += yl[l+0]; yl[l+0] = y[l+ 0];
yl[l+4] = y[l+16]; //sumy[1] += yl[l+4]; yl[l+4] = y[l+16];
yh[l+0] = y[l+32]; //sumy[2] += yh[l+0]; yh[l+0] = y[l+32];
yh[l+4] = y[l+48]; //sumy[3] += yh[l+4]; yh[l+4] = y[l+48];
} }
device const half * dh = &x[i].d; device const half * dh = &x[i].d;