mtl : another mul_mat Q4 (still does not work)

This commit is contained in:
Georgi Gerganov 2023-05-30 22:31:07 +03:00
parent 96d005225f
commit 29bec00ba0
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -137,45 +137,64 @@ kernel void kernel_mul_mat_q4_0(
const int qk = QK4_0;
const int nb = ne00/qk;
device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb;
device const float * y = (device const float *) (src1) + r1*ne10;
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10;
threadgroup float sum[32]; // TODO: should be equal to threadgroup size
sum[tpitg.x] = 0.0f;
for (int i = 0; i < nb; i += tptg.x) {
device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
device const float4 * y0p = (device const float4 *) (y + i*qk);
//device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
//device const float4 * y0p = (device const float4 *) (y + i*qk);
const uint4 x0 = *x0p;
//const uint4 x0 = *x0p;
const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
//const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
//const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
thread const char * x0lsb = (thread const char *) &x0l;
thread const char * x0hsb = (thread const char *) &x0h;
//thread const char * x0lsb = (thread const char *) &x0l;
//thread const char * x0hsb = (thread const char *) &x0h;
const float4 y00 = *(y0p + 0);
const float4 y01 = *(y0p + 1);
const float4 y02 = *(y0p + 2);
const float4 y03 = *(y0p + 3);
const float4 y04 = *(y0p + 4);
const float4 y05 = *(y0p + 5);
const float4 y06 = *(y0p + 6);
const float4 y07 = *(y0p + 7);
//const float4 y00 = *(y0p + 0);
//const float4 y01 = *(y0p + 1);
//const float4 y02 = *(y0p + 2);
//const float4 y03 = *(y0p + 3);
//const float4 y04 = *(y0p + 4);
//const float4 y05 = *(y0p + 5);
//const float4 y06 = *(y0p + 6);
//const float4 y07 = *(y0p + 7);
const half d = (x + i)->d;
//const half d = (x + i)->d;
sum[tpitg.x] += (
(x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
(x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
(x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
(x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
(x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
(x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
(x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
(x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
) * d;
//sum[tpitg.x] += (
// (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
// (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
// (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
// (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
// (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
// (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
// (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
// (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
// ) * d;
device const uchar * x0p = (device const uchar *) (x + i)->qs;
device const float * y0p = (device const float *) (y + i*qk);
float acc = 0.0f;
for (int j = 0; j < 16; ++j) {
const uchar x0v = *(x0p + j);
const int x0 = x0v & 0x0F;
const int x1 = x0v >> 4;
const float y0 = *(y0p + j);
const float y1 = *(y0p + j + 16);
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}
sum[tpitg.x] += acc * (x + i)->d;
}
// accumulate the sum from all threads in the threadgroup