mtl : another mul_mat Q4 (still does not work)
This commit is contained in:
parent
96d005225f
commit
29bec00ba0
1 changed file with 47 additions and 28 deletions
|
@@ -137,45 +137,64 @@ kernel void kernel_mul_mat_q4_0(
|
|||
const int qk = QK4_0;
|
||||
const int nb = ne00/qk;
|
||||
|
||||
device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb;
|
||||
device const float * y = (device const float *) (src1) + r1*ne10;
|
||||
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
|
||||
device const float * y = (device const float *) src1 + r1*ne10;
|
||||
|
||||
threadgroup float sum[32]; // TODO: should be equal to threadgroup size
|
||||
sum[tpitg.x] = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; i += tptg.x) {
|
||||
device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
|
||||
device const float4 * y0p = (device const float4 *) (y + i*qk);
|
||||
//device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
|
||||
//device const float4 * y0p = (device const float4 *) (y + i*qk);
|
||||
|
||||
const uint4 x0 = *x0p;
|
||||
//const uint4 x0 = *x0p;
|
||||
|
||||
const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
|
||||
const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
|
||||
//const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
|
||||
//const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
|
||||
|
||||
thread const char * x0lsb = (thread const char *) &x0l;
|
||||
thread const char * x0hsb = (thread const char *) &x0h;
|
||||
//thread const char * x0lsb = (thread const char *) &x0l;
|
||||
//thread const char * x0hsb = (thread const char *) &x0h;
|
||||
|
||||
const float4 y00 = *(y0p + 0);
|
||||
const float4 y01 = *(y0p + 1);
|
||||
const float4 y02 = *(y0p + 2);
|
||||
const float4 y03 = *(y0p + 3);
|
||||
const float4 y04 = *(y0p + 4);
|
||||
const float4 y05 = *(y0p + 5);
|
||||
const float4 y06 = *(y0p + 6);
|
||||
const float4 y07 = *(y0p + 7);
|
||||
//const float4 y00 = *(y0p + 0);
|
||||
//const float4 y01 = *(y0p + 1);
|
||||
//const float4 y02 = *(y0p + 2);
|
||||
//const float4 y03 = *(y0p + 3);
|
||||
//const float4 y04 = *(y0p + 4);
|
||||
//const float4 y05 = *(y0p + 5);
|
||||
//const float4 y06 = *(y0p + 6);
|
||||
//const float4 y07 = *(y0p + 7);
|
||||
|
||||
const half d = (x + i)->d;
|
||||
//const half d = (x + i)->d;
|
||||
|
||||
sum[tpitg.x] += (
|
||||
(x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
|
||||
(x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
|
||||
(x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
|
||||
(x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
|
||||
(x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
|
||||
(x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
|
||||
(x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
|
||||
(x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
|
||||
) * d;
|
||||
//sum[tpitg.x] += (
|
||||
// (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
|
||||
// (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
|
||||
// (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
|
||||
// (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
|
||||
// (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
|
||||
// (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
|
||||
// (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
|
||||
// (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
|
||||
// ) * d;
|
||||
|
||||
device const uchar * x0p = (device const uchar *) (x + i)->qs;
|
||||
device const float * y0p = (device const float *) (y + i*qk);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
const uchar x0v = *(x0p + j);
|
||||
|
||||
const int x0 = x0v & 0x0F;
|
||||
const int x1 = x0v >> 4;
|
||||
|
||||
const float y0 = *(y0p + j);
|
||||
const float y1 = *(y0p + j + 16);
|
||||
|
||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
||||
}
|
||||
|
||||
sum[tpitg.x] += acc * (x + i)->d;
|
||||
}
|
||||
|
||||
// accumulate the sum from all threads in the threadgroup
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue