mtl : initial get_rows_q4_0 kernel

This commit is contained in:
Georgi Gerganov 2023-05-29 23:12:19 +03:00
parent 248a8c3379
commit a8fd9dc128
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 90 additions and 10 deletions

View file

@ -29,6 +29,9 @@ struct ggml_mtl_context {
id<MTLFunction> function_soft_max;
id<MTLComputePipelineState> pipeline_soft_max;
id<MTLFunction> function_get_rows_q4_0;
id<MTLComputePipelineState> pipeline_get_rows_q4_0;
};
// MSL code
@ -90,7 +93,7 @@ struct ggml_mtl_context * llama_mtl_init(
{
NSError * error = nil;
NSString * path = [[NSBundle mainBundle] pathForResource:@"../examples/mtl/mtl" ofType:@"metal"];
NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/mtl/mtl" ofType:@"metal"];
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
@ -107,10 +110,7 @@ struct ggml_mtl_context * llama_mtl_init(
// load kernels
{
const int k_digits = 123;
MTLFunctionConstantValues * constants = [MTLFunctionConstantValues new];
[constants setConstantValue:&k_digits type:MTLDataTypeInt withName:@"k_digits"];
ctx->function_add = [ctx->library newFunctionWithName:@"kernel_add"];
ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
@ -123,6 +123,10 @@ struct ggml_mtl_context * llama_mtl_init(
ctx->function_soft_max = [ctx->library newFunctionWithName:@"kernel_soft_max" constantValues:constants error:nil];
ctx->pipeline_soft_max = [ctx->device newComputePipelineStateWithFunction:ctx->function_soft_max error:nil];
fprintf(stderr, "%s: loaded kernel_soft_max: %p\n", __func__, (void *) ctx->pipeline_soft_max);
ctx->function_get_rows_q4_0 = [ctx->library newFunctionWithName:@"kernel_get_rows_q4_0"];
ctx->pipeline_get_rows_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_get_rows_q4_0 error:nil];
fprintf(stderr, "%s: loaded kernel_get_rows_q4_0: %p\n", __func__, (void *) ctx->pipeline_get_rows_q4_0);
}
// MTLBuffer approach
@ -315,6 +319,35 @@ int llama_mtl_eval(
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
} break;
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
switch (gf->nodes[i]->src0->type) {
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
default: {
// not implemented
fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), ggml_type_name(gf->nodes[i]->src0->type));
}
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&(gf->nodes[i]->src0->ne[0]) length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(gf->nodes[i]->src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(gf->nodes[i]->nb[1]) length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(gf->nodes[i]->src1);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);

View file

@ -4,7 +4,35 @@ using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
constant int k_digits [[function_constant(0)]];
#define QK4_0 32
#define QR4_0 2
typedef struct {
half d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
const int qk = QK4_0;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const half d = x[i].d;
for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0x0F) - 8;
const int x1 = (x[i].qs[j] >> 4) - 8;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
// TODO: not needed
constant int nsoftmax [[function_constant(0)]];
kernel void kernel_add(
device const float * src0,
@ -21,20 +49,39 @@ kernel void kernel_relu(
dst[gid] = max(0.0f, src[gid]);
}
// TODO: broken
kernel void kernel_soft_max(
device const float * src,
device float * dst,
uint gid[[thread_position_in_grid]]) {
device float * dst) {
float max = 0.0f;
for (int i = 0; i < k_digits; i++) {
for (int i = 0; i < nsoftmax; i++) {
max = MAX(max, src[i]);
}
float sum = 0.0f;
for (int i = 0; i < k_digits; i++) {
for (int i = 0; i < nsoftmax; i++) {
dst[i] = exp(src[i] - max);
sum += dst[i];
}
for (int i = 0; i < k_digits; i++) {
for (int i = 0; i < nsoftmax; i++) {
dst[i] /= sum;
}
}
// TODO: not tested
kernel void kernel_get_rows_q4_0(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint gid[[thread_position_in_grid]]) {
device const block_q4_0 * src = (device const block_q4_0 *)src0;
const int i = gid;
const int r = ((device int32_t *) src1)[i];
dequantize_row_q4_0(
(device const block_q4_0 *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}