mtl : add f16 mat x f32 vec multiplication kernel
This commit is contained in:
parent f0196a7e7a
commit e55f7b0bdb
2 changed files with 68 additions and 44 deletions
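
The matrix x vector path in llama_mtl_eval no longer hardcodes the Q4_0 kernel: it switches on the src0 type, using kernel_mul_mat_q4_0_f32 (renamed from kernel_mul_mat_q4_0) for quantized weights and the new kernel_mul_mat_f16_f32 for f16 weights. The MPS matrix multiplication path is now taken only when ne11 > 1, the threadgroup size and threadgroup memory are driven by a single nth constant (16) instead of a hardcoded 32, and a large block of commented-out experimental code is dropped from the Q4_0 kernel.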
@@ -52,8 +52,11 @@ struct ggml_mtl_context {
     id<MTLFunction> function_rms_norm;
     id<MTLComputePipelineState> pipeline_rms_norm;
 
-    id<MTLFunction> function_mul_mat_q4_0;
-    id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
+    id<MTLFunction> function_mul_mat_q4_0_f32;
+    id<MTLComputePipelineState> pipeline_mul_mat_q4_0_f32;
+
+    id<MTLFunction> function_mul_mat_f16_f32;
+    id<MTLComputePipelineState> pipeline_mul_mat_f16_f32;
 
     id<MTLFunction> function_rope;
     id<MTLComputePipelineState> pipeline_rope;
@@ -183,9 +186,13 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
         fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
 
-        ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
-        ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
-        fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
+        ctx->function_mul_mat_q4_0_f32 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0_f32"];
+        ctx->pipeline_mul_mat_q4_0_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0_f32 error:nil];
+        fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0_f32: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0_f32);
+
+        ctx->function_mul_mat_f16_f32 = [ctx->library newFunctionWithName:@"kernel_mul_mat_f16_f32"];
+        ctx->pipeline_mul_mat_f16_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_f16_f32 error:nil];
+        fprintf(stderr, "%s: loaded kernel_mul_mat_f16_f32: %p\n", __func__, (void *) ctx->pipeline_mul_mat_f16_f32);
 
         ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
         ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
@@ -493,6 +500,8 @@ int llama_mtl_eval(
                     //const uint64_t nb1 = gf->nodes[i]->nb[1];
                     const uint64_t nb2 = gf->nodes[i]->nb[2];
 
+                    const int nth = 16;
+
                     const enum ggml_type src0t = gf->nodes[i]->src0->type;
                     const enum ggml_type src1t = gf->nodes[i]->src1->type;
                     const enum ggml_type dstt = gf->nodes[i]->type;
@@ -505,7 +514,7 @@ int llama_mtl_eval(
                     GGML_ASSERT(ne00 == ne10);
                     GGML_ASSERT(ne02 == ne12);
 
-                    if (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) {
+                    if ((src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
                         if (encoder != nil) {
                             [encoder endEncoding];
                             encoder = nil;
@@ -528,6 +537,8 @@ int llama_mtl_eval(
                                 initWithDevice:ctx->device transposeLeft:false transposeRight:true
                                     resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
 
+                        // we need to do ne02 multiplications
+                        // TODO: is there a way to do this in parallel - currently very slow ..
                         for (int64_t i02 = 0; i02 < ne02; ++i02) {
                             size_t offs_src0_cur = offs_src0 + i02*nb02;
                             size_t offs_src1_cur = offs_src1 + i02*nb12;
@@ -544,8 +555,13 @@ int llama_mtl_eval(
                             encoder = [command_buffer computeCommandEncoder];
                         }
 
-                        // for Q4 x F32 we use custom kernel
-                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
+                        // use custom matrix x vector kernel
+                        switch (src0t) {
+                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; break;
+                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        };
+
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -555,9 +571,9 @@ int llama_mtl_eval(
                         [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
                         [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7];
                         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8];
-                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     }
                 } break;
            case GGML_OP_GET_ROWS:
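Both custom kernels end with the same threadgroup tree reduction (see kernel_mul_mat_f16_f32 below): each of the nth threads deposits a partial dot product into threadgroup memory, then the number of active threads is halved on every pass until sum[0] holds the total. A minimal serial sketch of that reduction for nth = 16 (the 1.0f inputs are illustrative stand-ins for per-thread partial sums):

#include <stdio.h>

int main(void) {
    enum { nth = 16 };
    float sum[nth];
    for (int t = 0; t < nth; ++t) {
        sum[t] = 1.0f; // stand-in for each thread's partial dot product
    }

    // halve the active "thread" count each pass; log2(16) = 4 passes total
    for (int i = nth/2; i > 0; i /= 2) {
        for (int t = 0; t < i; ++t) { // the "tpitg.x < i" threads are active
            sum[t] += sum[t + i];
        }
        // (on the GPU a threadgroup barrier separates the passes)
    }

    printf("%f\n", sum[0]); // prints 16.000000
    return 0;
}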
@@ -241,7 +241,7 @@ kernel void kernel_rms_norm(
     }
 }
 
-kernel void kernel_mul_mat_q4_0(
+kernel void kernel_mul_mat_q4_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -268,39 +268,6 @@ kernel void kernel_mul_mat_q4_0(
     sum[tpitg.x] = 0.0f;
 
     for (int i = tpitg.x; i < nb; i += tptg.x) {
-        //device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
-        //device const float4 * y0p = (device const float4 *) (y + i*qk);
-
-        //const uint4 x0 = *x0p;
-
-        //const uint4 x0l = (x0 & uint4(0x0F0F0F0F));
-        //const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4;
-
-        //thread const char * x0lsb = (thread const char *) &x0l;
-        //thread const char * x0hsb = (thread const char *) &x0h;
-
-        //const float4 y00 = *(y0p + 0);
-        //const float4 y01 = *(y0p + 1);
-        //const float4 y02 = *(y0p + 2);
-        //const float4 y03 = *(y0p + 3);
-        //const float4 y04 = *(y0p + 4);
-        //const float4 y05 = *(y0p + 5);
-        //const float4 y06 = *(y0p + 6);
-        //const float4 y07 = *(y0p + 7);
-
-        //const half d = (x + i)->d;
-
-        //sum[tpitg.x] += (
-        //    (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] +
-        //    (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] +
-        //    (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] +
-        //    (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] +
-        //    (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] +
-        //    (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] +
-        //    (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] +
-        //    (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3]
-        //  ) * d;
-
         device const uchar * x0p = (device const uchar *) (x + i)->qs;
         device const float * y0p = (device const float *) (y + i*qk);
 
@@ -335,6 +302,47 @@ kernel void kernel_mul_mat_q4_0(
     }
 }
 
+kernel void kernel_mul_mat_f16_f32(
+        device const half * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const half * x = src0 + r0*ne00;
+    device const float * y = src1 + r1*ne10;
+
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = tpitg.x; i < ne00; i += tptg.x) {
+        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    if (tpitg.x == 0) {
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
 kernel void kernel_rope(
         device const void * src0,
         device float * dst,
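For validation, kernel_mul_mat_f16_f32 computes an ordinary matrix x vector product per (r0, r1) pair. A CPU reference sketch of the same result, with src0 held as float for brevity where the kernel reads half values:

#include <stdint.h>

// dst[r1*ne0 + r0] = dot(row r0 of src0, row r1 of src1), as in the kernel
static void mul_mat_f16_f32_ref(
        const float * src0, // ne01 rows of ne00 elements (half in the kernel)
        const float * src1, // ne11 rows of ne10 elements, ne10 == ne00
        float       * dst,  // written at dst[r1*ne0 + r0]
        int64_t ne00, int64_t ne01, int64_t ne10, int64_t ne11, int64_t ne0) {
    for (int64_t r1 = 0; r1 < ne11; ++r1) {
        for (int64_t r0 = 0; r0 < ne01; ++r0) {
            const float * x = src0 + r0*ne00;
            const float * y = src1 + r1*ne10;

            float sum = 0.0f;
            for (int64_t i = 0; i < ne00; ++i) {
                sum += x[i]*y[i];
            }

            dst[r1*ne0 + r0] = sum;
        }
    }
}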