mtl : faster mul_mat_q4_0_f32 kernel

Georgi Gerganov 2023-06-02 18:28:31 +03:00
parent 33671460b0
commit 847bbfe9e6
2 changed files with 54 additions and 19 deletions


@@ -555,7 +555,8 @@ int llama_mtl_eval(
                        encoder = [command_buffer computeCommandEncoder];
                    }
 
-                    int nth = 32;
+                    int nth0 = 32;
+                    int nth1 = 1;
 
                    // use custom matrix x vector kernel
                    switch (src0t) {
@@ -564,14 +565,16 @@ int llama_mtl_eval(
                                GGML_ASSERT(ne02 == 1);
                                GGML_ASSERT(ne12 == 1);
 
-                                nth = 4;
+                                nth0 = 8;
+                                nth1 = 4;
                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                            } break;
                        case GGML_TYPE_F16:
                            {
                                GGML_ASSERT(ne02 == ne12);
 
-                                nth = 32;
+                                nth0 = 32;
+                                nth1 = 1;
                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                            } break;
                        default: GGML_ASSERT(false && "not implemented");
@@ -595,11 +598,11 @@ int llama_mtl_eval(
                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
 
                    if (src0t == GGML_TYPE_Q4_0) {
-                        [encoder setThreadgroupMemoryLength:16*nth*sizeof(float) atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 16, 1)];
+                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    } else {
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                }
            } break;
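
Net effect on the Q4_0 dispatch: the old code launched (nth, 16, 1) = (4, 16, 1) threadgroups (64 threads, with the 16 hard-wired into both the 16*nth*sizeof(float) scratch size and the kernel's index math), while the new code launches (nth0, nth1, 1) = (8, 4, 1) threadgroups (32 threads) and sizes the threadgroup scratch to exactly one float per thread. A minimal host-side sketch, not part of the commit, that checks the kernel's new flattened index ith = tptg.y*tpitg.x + tpitg.y touches each scratch slot exactly once under these launch parameters:

/* Hypothetical standalone check (names and structure are mine, not the
 * commit's): verify the scratch sizing and the flat-index bijection. */
#include <assert.h>
#include <stdio.h>

int main(void) {
    const int nth0 = 8; // threads along x: stride over the row's quant blocks
    const int nth1 = 4; // threads along y: 4-byte chunks within one block

    // one float of threadgroup scratch per thread (was 16*nth floats)
    const int scratch_bytes = nth0*nth1*(int)sizeof(float); // 32*4 = 128

    // ith = tptg.y*tpitg.x + tpitg.y must map the 8x4 grid onto [0, 32) 1:1
    int seen[8*4] = {0};
    for (int x = 0; x < nth0; ++x) {
        for (int y = 0; y < nth1; ++y) {
            const int ith = nth1*x + y; // tptg.y == nth1
            assert(0 <= ith && ith < nth0*nth1 && !seen[ith]);
            seen[ith] = 1;
        }
    }
    printf("scratch = %d bytes, index map OK\n", scratch_bytes);
    return 0;
}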


@@ -266,31 +266,63 @@ kernel void kernel_mul_mat_q4_0_f32(
    device const float * y = (device const float *) src1 + r1*ne10;
 
    const uint nth = tptg.x*tptg.y;
-    const uint ith = 16*tpitg.x + tpitg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
 
    sum[ith] = 0.0f;
 
    for (int i = tpitg.x; i < nb; i += tptg.x) {
-        device const uchar * x0p = (device const uchar *) (x + i)->qs;
-        device const float * y0p = (device const float *) (y + i*qk);
+        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
+        device const float4 * y0p = (device const float4 *) (y + i*qk);
+
+        const float d = (float)((x + i)->d);
 
+        const uchar4 x0v = *(x0p + tpitg.y);
+        const float4 y0v = *(y0p + tpitg.y + 0);
+        const float4 y1v = *(y0p + tpitg.y + 4);
+
        float acc = 0.0f;
 
-        //for (int j = 0; j < 16; ++j) {
-        const int j = tpitg.y;
        {
-            const uchar x0v = *(x0p + j);
-
-            const int x0 = x0v & 0x0F;
-            const int x1 = x0v >> 4;
+            const int x0 = x0v[0] & 0x0F;
+            const int x1 = x0v[0] >> 4;
 
-            const float y0 = *(y0p + j);
-            const float y1 = *(y0p + j + 16);
+            const float y0 = y0v[0];
+            const float y1 = y1v[0];
 
            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
        }
 
-        sum[ith] += acc * (x + i)->d;
+        {
+            const int x0 = x0v[1] & 0x0F;
+            const int x1 = x0v[1] >> 4;
+
+            const float y0 = y0v[1];
+            const float y1 = y1v[1];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        {
+            const int x0 = x0v[2] & 0x0F;
+            const int x1 = x0v[2] >> 4;
+
+            const float y0 = y0v[2];
+            const float y1 = y1v[2];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        {
+            const int x0 = x0v[3] & 0x0F;
+            const int x1 = x0v[3] >> 4;
+
+            const float y0 = y0v[3];
+            const float y1 = y1v[3];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        sum[ith] += acc*d;
    }
 
    // accumulate the sum from all threads in the threadgroup
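
The unrolled loop body computes the same per-block contribution as the scalar version it replaces: each thread loads four packed bytes as a uchar4 together with the two matching float4 slices of y, where the low nibble of byte j holds element j of the block and the high nibble holds element j + 16, both biased by -8 and scaled by the block's d. A scalar C reference of this Q4_0 dot product, a sketch for illustration rather than code from the commit (the struct mirrors ggml's block_q4_0 but stores d as float so the example stays self-contained):

#include <stdint.h>

#define QK4_0 32

// Mirrors ggml's block_q4_0, except d is a float here (ggml stores fp16).
typedef struct {
    float   d;           // per-block scale
    uint8_t qs[QK4_0/2]; // 16 bytes, two 4-bit elements per byte
} block_q4_0_ref;

// Dot product of nb quantized blocks x against nb*QK4_0 floats y.
static float vec_dot_q4_0_f32_ref(int nb, const block_q4_0_ref *x, const float *y) {
    float sum = 0.0f;
    for (int i = 0; i < nb; ++i) {
        const float *yb = y + i*QK4_0;
        float acc = 0.0f;
        for (int j = 0; j < QK4_0/2; ++j) {
            const int x0 = (x[i].qs[j] & 0x0F) - 8; // low nibble  -> element j
            const int x1 = (x[i].qs[j] >>   4) - 8; // high nibble -> element j + 16
            acc += x0*yb[j] + x1*yb[j + QK4_0/2];
        }
        sum += acc*x[i].d;
    }
    return sum;
}

The kernel splits the inner j loop across tpitg.y (four threads, one uchar4/float4 chunk each) and the block loop over i across tpitg.x; the per-thread partials sum[ith] are then combined by the threadgroup reduction that follows.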