mtl : faster mul_mat_q4_0_f32 kernel
This commit is contained in:
parent
33671460b0
commit
847bbfe9e6
2 changed files with 54 additions and 19 deletions
|
@ -555,7 +555,8 @@ int llama_mtl_eval(
|
|||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
int nth = 32;
|
||||
int nth0 = 32;
|
||||
int nth1 = 1;
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src0t) {
|
||||
|
@ -564,14 +565,16 @@ int llama_mtl_eval(
|
|||
GGML_ASSERT(ne02 == 1);
|
||||
GGML_ASSERT(ne12 == 1);
|
||||
|
||||
nth = 4;
|
||||
nth0 = 8;
|
||||
nth1 = 4;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
|
||||
nth = 32;
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||
} break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
|
@ -595,11 +598,11 @@ int llama_mtl_eval(
|
|||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||
|
||||
if (src0t == GGML_TYPE_Q4_0) {
|
||||
[encoder setThreadgroupMemoryLength:16*nth*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 16, 1)];
|
||||
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
}
|
||||
} break;
|
||||
|
|
|
@ -266,31 +266,63 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||
device const float * y = (device const float *) src1 + r1*ne10;
|
||||
|
||||
const uint nth = tptg.x*tptg.y;
|
||||
const uint ith = 16*tpitg.x + tpitg.y;
|
||||
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
||||
|
||||
sum[ith] = 0.0f;
|
||||
|
||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||
device const uchar * x0p = (device const uchar *) (x + i)->qs;
|
||||
device const float * y0p = (device const float *) (y + i*qk);
|
||||
device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
|
||||
device const float4 * y0p = (device const float4 *) (y + i*qk);
|
||||
|
||||
const float d = (float)((x + i)->d);
|
||||
|
||||
const uchar4 x0v = *(x0p + tpitg.y);
|
||||
const float4 y0v = *(y0p + tpitg.y + 0);
|
||||
const float4 y1v = *(y0p + tpitg.y + 4);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
//for (int j = 0; j < 16; ++j) {
|
||||
const int j = tpitg.y;
|
||||
{
|
||||
const uchar x0v = *(x0p + j);
|
||||
const int x0 = x0v[0] & 0x0F;
|
||||
const int x1 = x0v[0] >> 4;
|
||||
|
||||
const int x0 = x0v & 0x0F;
|
||||
const int x1 = x0v >> 4;
|
||||
|
||||
const float y0 = *(y0p + j);
|
||||
const float y1 = *(y0p + j + 16);
|
||||
const float y0 = y0v[0];
|
||||
const float y1 = y1v[0];
|
||||
|
||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
||||
}
|
||||
|
||||
sum[ith] += acc * (x + i)->d;
|
||||
{
|
||||
const int x0 = x0v[1] & 0x0F;
|
||||
const int x1 = x0v[1] >> 4;
|
||||
|
||||
const float y0 = y0v[1];
|
||||
const float y1 = y1v[1];
|
||||
|
||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
||||
}
|
||||
|
||||
{
|
||||
const int x0 = x0v[2] & 0x0F;
|
||||
const int x1 = x0v[2] >> 4;
|
||||
|
||||
const float y0 = y0v[2];
|
||||
const float y1 = y1v[2];
|
||||
|
||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
||||
}
|
||||
|
||||
{
|
||||
const int x0 = x0v[3] & 0x0F;
|
||||
const int x1 = x0v[3] >> 4;
|
||||
|
||||
const float y0 = y0v[3];
|
||||
const float y1 = y1v[3];
|
||||
|
||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
||||
}
|
||||
|
||||
sum[ith] += acc*d;
|
||||
}
|
||||
|
||||
// accumulate the sum from all threads in the threadgroup
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue