From a8fd9dc12870c0c828c200f599e005ee1989148f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 May 2023 23:12:19 +0300 Subject: [PATCH] mtl : initial get_rows_q4_0 kernel --- examples/mtl/mtl.m | 41 ++++++++++++++++++++++++++--- examples/mtl/mtl.metal | 59 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 90 insertions(+), 10 deletions(-) diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m index 86e0b0c78..822edec20 100644 --- a/examples/mtl/mtl.m +++ b/examples/mtl/mtl.m @@ -29,6 +29,9 @@ struct ggml_mtl_context { id function_soft_max; id pipeline_soft_max; + + id function_get_rows_q4_0; + id pipeline_get_rows_q4_0; }; // MSL code @@ -90,7 +93,7 @@ struct ggml_mtl_context * llama_mtl_init( { NSError * error = nil; - NSString * path = [[NSBundle mainBundle] pathForResource:@"../examples/mtl/mtl" ofType:@"metal"]; + NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/mtl/mtl" ofType:@"metal"]; NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; if (error) { fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]); @@ -107,10 +110,7 @@ struct ggml_mtl_context * llama_mtl_init( // load kernels { - const int k_digits = 123; - MTLFunctionConstantValues * constants = [MTLFunctionConstantValues new]; - [constants setConstantValue:&k_digits type:MTLDataTypeInt withName:@"k_digits"]; ctx->function_add = [ctx->library newFunctionWithName:@"kernel_add"]; ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil]; @@ -123,6 +123,10 @@ struct ggml_mtl_context * llama_mtl_init( ctx->function_soft_max = [ctx->library newFunctionWithName:@"kernel_soft_max" constantValues:constants error:nil]; ctx->pipeline_soft_max = [ctx->device newComputePipelineStateWithFunction:ctx->function_soft_max error:nil]; fprintf(stderr, "%s: loaded kernel_soft_max: %p\n", __func__, (void *) ctx->pipeline_soft_max); + + ctx->function_get_rows_q4_0 = [ctx->library newFunctionWithName:@"kernel_get_rows_q4_0"]; + ctx->pipeline_get_rows_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_get_rows_q4_0 error:nil]; + fprintf(stderr, "%s: loaded kernel_get_rows_q4_0: %p\n", __func__, (void *) ctx->pipeline_get_rows_q4_0); } // MTLBuffer approach @@ -315,6 +319,35 @@ int llama_mtl_eval( [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst]; } break; + case GGML_OP_GET_ROWS: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + id id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0); + id id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1); + id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); + + switch (gf->nodes[i]->src0->type) { + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + default: { + // not implemented + fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), ggml_type_name(gf->nodes[i]->src0->type)); + } + } + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&(gf->nodes[i]->src0->ne[0]) length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&(gf->nodes[i]->src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&(gf->nodes[i]->nb[1]) length:sizeof(uint64_t) atIndex:5]; + + const int64_t n = ggml_nelements(gf->nodes[i]->src1); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); GGML_ASSERT(false); diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal index e9597336c..33370fd6a 100644 --- a/examples/mtl/mtl.metal +++ b/examples/mtl/mtl.metal @@ -4,7 +4,35 @@ using namespace metal; #define MAX(x, y) ((x) > (y) ? (x) : (y)) -constant int k_digits [[function_constant(0)]]; +#define QK4_0 32 +#define QR4_0 2 +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; + +static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) { + const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const half d = x[i].d; + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +// TODO: not needed +constant int nsoftmax [[function_constant(0)]]; kernel void kernel_add( device const float * src0, @@ -21,20 +49,39 @@ kernel void kernel_relu( dst[gid] = max(0.0f, src[gid]); } +// TODO: broken kernel void kernel_soft_max( device const float * src, - device float * dst, - uint gid[[thread_position_in_grid]]) { + device float * dst) { float max = 0.0f; - for (int i = 0; i < k_digits; i++) { + for (int i = 0; i < nsoftmax; i++) { max = MAX(max, src[i]); } float sum = 0.0f; - for (int i = 0; i < k_digits; i++) { + for (int i = 0; i < nsoftmax; i++) { dst[i] = exp(src[i] - max); sum += dst[i]; } - for (int i = 0; i < k_digits; i++) { + for (int i = 0; i < nsoftmax; i++) { dst[i] /= sum; } } + +// TODO: not tested +kernel void kernel_get_rows_q4_0( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint gid[[thread_position_in_grid]]) { + device const block_q4_0 * src = (device const block_q4_0 *)src0; + + const int i = gid; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q4_0( + (device const block_q4_0 *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +}