mtl : faster mul_mat_q4_0_f32 kernel
parent 33671460b0
commit 847bbfe9e6
2 changed files with 54 additions and 19 deletions
@@ -555,7 +555,8 @@ int llama_mtl_eval(
                 encoder = [command_buffer computeCommandEncoder];
             }
 
-            int nth = 32;
+            int nth0 = 32;
+            int nth1 = 1;
 
             // use custom matrix x vector kernel
             switch (src0t) {
@@ -564,14 +565,16 @@ int llama_mtl_eval(
                         GGML_ASSERT(ne02 == 1);
                         GGML_ASSERT(ne12 == 1);
 
-                        nth = 4;
+                        nth0 = 8;
+                        nth1 = 4;
                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                     } break;
                 case GGML_TYPE_F16:
                     {
                         GGML_ASSERT(ne02 == ne12);
 
-                        nth = 32;
+                        nth0 = 32;
+                        nth1 = 1;
                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                     } break;
                 default: GGML_ASSERT(false && "not implemented");
@@ -595,11 +598,11 @@ int llama_mtl_eval(
             [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
 
             if (src0t == GGML_TYPE_Q4_0) {
-                [encoder setThreadgroupMemoryLength:16*nth*sizeof(float) atIndex:0];
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 16, 1)];
+                [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             } else {
-                [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
         }
     } break;
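For context on what these threadgroup shapes are feeding, the q4_0 data both files operate on is laid out as fixed 32-element blocks. The sketch below is a plain-C rendering of the block_q4_0 layout assumed here (it is not part of this commit, and the half-precision scale is shown as raw 16-bit storage for brevity):

#include <stdint.h>

// Assumed q4_0 block layout (32 weights per block):
typedef struct {
    uint16_t d;       // per-block scale (IEEE half, kept as raw bits in this sketch)
    uint8_t  qs[16];  // 32 x 4-bit quants packed two per byte:
                      //   low  nibble of qs[j] -> element j
                      //   high nibble of qs[j] -> element j + 16
} block_q4_0;

With nth1 = 4, the four threads along the threadgroup's y axis each cover 4 of these 16 bytes, which is what the uchar4/float4 loads in the kernel below exploit.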
@@ -266,31 +266,63 @@ kernel void kernel_mul_mat_q4_0_f32(
     device const float * y = (device const float *) src1 + r1*ne10;
 
     const uint nth = tptg.x*tptg.y;
-    const uint ith = 16*tpitg.x + tpitg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
 
     sum[ith] = 0.0f;
 
     for (int i = tpitg.x; i < nb; i += tptg.x) {
-        device const uchar * x0p = (device const uchar *) (x + i)->qs;
-        device const float * y0p = (device const float *) (y + i*qk);
+        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
+        device const float4 * y0p = (device const float4 *) (y + i*qk);
 
+        const float d = (float)((x + i)->d);
+
+        const uchar4 x0v = *(x0p + tpitg.y);
+        const float4 y0v = *(y0p + tpitg.y + 0);
+        const float4 y1v = *(y0p + tpitg.y + 4);
+
         float acc = 0.0f;
 
-        //for (int j = 0; j < 16; ++j) {
-        const int j = tpitg.y;
         {
-            const uchar x0v = *(x0p + j);
-
-            const int x0 = x0v & 0x0F;
-            const int x1 = x0v >> 4;
+            const int x0 = x0v[0] & 0x0F;
+            const int x1 = x0v[0] >> 4;
 
-            const float y0 = *(y0p + j);
-            const float y1 = *(y0p + j + 16);
+            const float y0 = y0v[0];
+            const float y1 = y1v[0];
 
             acc += (x0 - 8)*y0 + (x1 - 8)*y1;
         }
 
-        sum[ith] += acc * (x + i)->d;
+        {
+            const int x0 = x0v[1] & 0x0F;
+            const int x1 = x0v[1] >> 4;
+
+            const float y0 = y0v[1];
+            const float y1 = y1v[1];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        {
+            const int x0 = x0v[2] & 0x0F;
+            const int x1 = x0v[2] >> 4;
+
+            const float y0 = y0v[2];
+            const float y1 = y1v[2];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        {
+            const int x0 = x0v[3] & 0x0F;
+            const int x1 = x0v[3] >> 4;
+
+            const float y0 = y0v[3];
+            const float y1 = y1v[3];
+
+            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
+        }
+
+        sum[ith] += acc*d;
     }
 
     // accumulate the sum from all threads in the threadgroup
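To make the unrolled arithmetic above easier to follow, here is a scalar reference of the per-block dot product the kernel computes; block_dot_q4_0 is a hypothetical helper written for illustration, not code from this commit:

#include <stdint.h>

// Scalar reference: one q4_0 block (16 packed bytes, scale d) against 32 floats of y.
static float block_dot_q4_0(const uint8_t qs[16], float d, const float y[32]) {
    float acc = 0.0f;
    for (int j = 0; j < 16; ++j) {
        const int x0 = (qs[j] & 0x0F) - 8;  // low nibble  pairs with y[j]
        const int x1 = (qs[j] >>   4) - 8;  // high nibble pairs with y[j + 16]
        acc += x0*y[j] + x1*y[j + 16];
    }
    return acc*d;                           // apply the per-block scale once
}

In the kernel, this loop body is split four ways: each thread loads one uchar4 (bytes 4*tpitg.y .. 4*tpitg.y + 3) plus the matching float4 pair from y, the tptg.x threads stride over blocks i, and the per-thread totals land in sum[ith] for the threadgroup reduction that follows.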