mtl : another mul_mat Q4 (still does not work)

This commit is contained in:
Georgi Gerganov 2023-05-30 22:31:07 +03:00
parent 96d005225f
commit 29bec00ba0
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -137,45 +137,64 @@ kernel void kernel_mul_mat_q4_0(
const int qk = QK4_0; const int qk = QK4_0;
const int nb = ne00/qk; const int nb = ne00/qk;
device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb; device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
device const float * y = (device const float *) (src1) + r1*ne10; device const float * y = (device const float *) src1 + r1*ne10;
threadgroup float sum[32]; // TODO: should be equal to threadgroup size threadgroup float sum[32]; // TODO: should be equal to threadgroup size
sum[tpitg.x] = 0.0f; sum[tpitg.x] = 0.0f;
for (int i = 0; i < nb; i += tptg.x) { for (int i = 0; i < nb; i += tptg.x) {
device const uint4 * x0p = (device const uint4 *) (x + i)->qs; //device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
device const float4 * y0p = (device const float4 *) (y + i*qk); //device const float4 * y0p = (device const float4 *) (y + i*qk);
const uint4 x0 = *x0p; //const uint4 x0 = *x0p;
const uint4 x0l = (x0 & uint4(0x0F0F0F0F)); //const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4; //const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
thread const char * x0lsb = (thread const char *) &x0l; //thread const char * x0lsb = (thread const char *) &x0l;
thread const char * x0hsb = (thread const char *) &x0h; //thread const char * x0hsb = (thread const char *) &x0h;
const float4 y00 = *(y0p + 0); //const float4 y00 = *(y0p + 0);
const float4 y01 = *(y0p + 1); //const float4 y01 = *(y0p + 1);
const float4 y02 = *(y0p + 2); //const float4 y02 = *(y0p + 2);
const float4 y03 = *(y0p + 3); //const float4 y03 = *(y0p + 3);
const float4 y04 = *(y0p + 4); //const float4 y04 = *(y0p + 4);
const float4 y05 = *(y0p + 5); //const float4 y05 = *(y0p + 5);
const float4 y06 = *(y0p + 6); //const float4 y06 = *(y0p + 6);
const float4 y07 = *(y0p + 7); //const float4 y07 = *(y0p + 7);
const half d = (x + i)->d; //const half d = (x + i)->d;
sum[tpitg.x] += ( //sum[tpitg.x] += (
(x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] + // (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
(x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] + // (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
(x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] + // (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
(x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] + // (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
(x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] + // (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
(x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] + // (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
(x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] + // (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
(x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3] // (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
) * d; // ) * d;
device const uchar * x0p = (device const uchar *) (x + i)->qs;
device const float * y0p = (device const float *) (y + i*qk);
float acc = 0.0f;
for (int j = 0; j < 16; ++j) {
const uchar x0v = *(x0p + j);
const int x0 = x0v & 0x0F;
const int x1 = x0v >> 4;
const float y0 = *(y0p + j);
const float y1 = *(y0p + j + 16);
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}
sum[tpitg.x] += acc * (x + i)->d;
} }
// accumulate the sum from all threads in the threadgroup // accumulate the sum from all threads in the threadgroup