mtl : initial get_rows_q4_0 kernel

parent 248a8c3379
commit a8fd9dc128

2 changed files with 90 additions and 10 deletions

@@ -29,6 +29,9 @@ struct ggml_mtl_context {
    id<MTLFunction>             function_soft_max;
    id<MTLComputePipelineState> pipeline_soft_max;

    id<MTLFunction>             function_get_rows_q4_0;
    id<MTLComputePipelineState> pipeline_get_rows_q4_0;
};

// MSL code

@@ -90,7 +93,7 @@ struct ggml_mtl_context * llama_mtl_init(
    {
        NSError * error = nil;

        NSString * path = [[NSBundle mainBundle] pathForResource:@"../examples/mtl/mtl" ofType:@"metal"];
        NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/mtl/mtl" ofType:@"metal"];
        NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
        if (error) {
            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);

@@ -107,10 +110,7 @@ struct ggml_mtl_context * llama_mtl_init(
    // load kernels
    {
        const int k_digits = 123;

        MTLFunctionConstantValues * constants = [MTLFunctionConstantValues new];
        [constants setConstantValue:&k_digits type:MTLDataTypeInt withName:@"k_digits"];

        ctx->function_add = [ctx->library newFunctionWithName:@"kernel_add"];
        ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];

@@ -123,6 +123,10 @@ struct ggml_mtl_context * llama_mtl_init(
        ctx->function_soft_max = [ctx->library newFunctionWithName:@"kernel_soft_max" constantValues:constants error:nil];
        ctx->pipeline_soft_max = [ctx->device newComputePipelineStateWithFunction:ctx->function_soft_max error:nil];
        fprintf(stderr, "%s: loaded kernel_soft_max: %p\n", __func__, (void *) ctx->pipeline_soft_max);

        ctx->function_get_rows_q4_0 = [ctx->library newFunctionWithName:@"kernel_get_rows_q4_0"];
        ctx->pipeline_get_rows_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_get_rows_q4_0 error:nil];
        fprintf(stderr, "%s: loaded kernel_get_rows_q4_0: %p\n", __func__, (void *) ctx->pipeline_get_rows_q4_0);
    }

    // MTLBuffer approach

@@ -315,6 +319,35 @@ int llama_mtl_eval(
                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
                } break;
            case GGML_OP_GET_ROWS:
                {
                    if (encoder == nil) {
                        encoder = [command_buffer computeCommandEncoder];
                    }

                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                    switch (gf->nodes[i]->src0->type) {
                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                        default: {
                            // not implemented
                            fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), ggml_type_name(gf->nodes[i]->src0->type));
                        }
                    }

                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                    [encoder setBytes:&(gf->nodes[i]->src0->ne[0]) length:sizeof( int64_t) atIndex:3];
                    [encoder setBytes:&(gf->nodes[i]->src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
                    [encoder setBytes:&(gf->nodes[i]->nb[1])       length:sizeof(uint64_t) atIndex:5];

                    const int64_t n = ggml_nelements(gf->nodes[i]->src1);

                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            default:
                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                GGML_ASSERT(false);
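
Aside: a rough CPU reference for the GET_ROWS path encoded above, to make the buffer and constant bindings concrete. This is a hedged sketch, not code from the commit; the names block_q4_0_ref and get_rows_q4_0_ref are illustrative, and it assumes a compiler with _Float16 support for the half-precision delta. Each of the n = ggml_nelements(src1) threadgroups dispatched above corresponds to one iteration of the outer loop: read one row index from src1, dequantize that q4_0 row of src0 (ne00 elements, rows nb01 bytes apart), and write it to row i of dst (rows nb1 bytes apart).

    #include <stdint.h>

    #define QK4_0_REF 32

    typedef struct {
        _Float16 d;                  // delta (assumes compiler _Float16 support)
        uint8_t  qs[QK4_0_REF / 2];  // nibbles / quants
    } block_q4_0_ref;

    static void get_rows_q4_0_ref(
            const void    * src0,  // quantized matrix         (atIndex:0)
            const int32_t * src1,  // row indices to gather    (atIndex:1)
            float         * dst,   // dequantized output rows  (atIndex:2)
            int64_t         n,     // ggml_nelements(src1), one threadgroup per index
            int64_t         ne00,  // elements per src0 row    (atIndex:3)
            uint64_t        nb01,  // src0 row stride in bytes (atIndex:4)
            uint64_t        nb1) { // dst  row stride in bytes (atIndex:5)
        for (int64_t i = 0; i < n; ++i) {
            const int32_t r = src1[i];

            const block_q4_0_ref * x = (const block_q4_0_ref *) ((const char *) src0 + r*nb01);
            float                * y = (float *) ((char *) dst + i*nb1);

            for (int64_t ib = 0; ib < ne00/QK4_0_REF; ++ib) {
                const float d = (float) x[ib].d;

                for (int j = 0; j < QK4_0_REF/2; ++j) {
                    const int x0 = (x[ib].qs[j] & 0x0F) - 8;
                    const int x1 = (x[ib].qs[j] >>   4) - 8;

                    y[ib*QK4_0_REF + j               ] = x0*d;
                    y[ib*QK4_0_REF + j + QK4_0_REF/2 ] = x1*d;
                }
            }
        }
    }
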
@@ -4,7 +4,35 @@ using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))

constant int k_digits [[function_constant(0)]];
#define QK4_0 32
#define QR4_0 2
typedef struct {
    half    d;              // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;

static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
    const int qk = QK4_0;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        const half d = x[i].d;

        for (int j = 0; j < qk/2; ++j) {
            const int x0 = (x[i].qs[j] & 0x0F) - 8;
            const int x1 = (x[i].qs[j] >>   4) - 8;

            y[i*qk + j + 0   ] = x0*d;
            y[i*qk + j + qk/2] = x1*d;
        }
    }
}

// TODO: not needed
constant int nsoftmax [[function_constant(0)]];

kernel void kernel_add(
        device const float * src0,

@@ -21,20 +49,39 @@ kernel void kernel_relu(
    dst[gid] = max(0.0f, src[gid]);
}

// TODO: broken
kernel void kernel_soft_max(
        device const float * src,
        device float * dst,
        uint gid[[thread_position_in_grid]]) {
        device float * dst) {
    float max = 0.0f;
    for (int i = 0; i < k_digits; i++) {
    for (int i = 0; i < nsoftmax; i++) {
        max = MAX(max, src[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < k_digits; i++) {
    for (int i = 0; i < nsoftmax; i++) {
        dst[i] = exp(src[i] - max);
        sum += dst[i];
    }
    for (int i = 0; i < k_digits; i++) {
    for (int i = 0; i < nsoftmax; i++) {
        dst[i] /= sum;
    }
}

// TODO: not tested
kernel void kernel_get_rows_q4_0(
        device const void * src0,
        device const int * src1,
        device float * dst,
        constant int64_t & ne00,
        constant uint64_t & nb01,
        constant uint64_t & nb1,
        uint gid[[thread_position_in_grid]]) {
    device const block_q4_0 * src = (device const block_q4_0 *)src0;

    const int i = gid;
    const int r = ((device int32_t *) src1)[i];

    dequantize_row_q4_0(
        (device const block_q4_0 *) ((device char *) src0 + r*nb01),
        (device float *) ((device char *) dst + i*nb1), ne00);
}
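
A quick sanity check on the q4_0 layout used by dequantize_row_q4_0 above (illustrative numbers, not from the commit): each packed byte carries two quants, the low nibbles fill the first half of a 32-element block and the high nibbles fill the second half, each offset by 8 and scaled by the block delta d.

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const float   d  = 0.5f;  // block delta
        const uint8_t qs = 0x2A;  // packed byte j = 0: low nibble 0xA, high nibble 0x2

        const float y_lo = ((qs & 0x0F) - 8) * d;  // -> y[0]  = (10 - 8)*0.5 =  1.0
        const float y_hi = ((qs >>   4) - 8) * d;  // -> y[16] = ( 2 - 8)*0.5 = -3.0

        assert(y_lo == 1.0f && y_hi == -3.0f);
        return 0;
    }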