mtl : add f32 -> f32 cpy kernel

Georgi Gerganov 2023-06-01 21:30:33 +03:00
parent a266c26de2
commit a0cc3de59a
3 changed files with 55 additions and 5 deletions

@@ -53,6 +53,9 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_cpy_f32_f16;
     id<MTLComputePipelineState> pipeline_cpy_f32_f16;
+
+    id<MTLFunction>             function_cpy_f32_f32;
+    id<MTLComputePipelineState> pipeline_cpy_f32_f32;
 };

 // MSL code
@@ -176,6 +179,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_cpy_f32_f16 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f16"];
         ctx->pipeline_cpy_f32_f16 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f16 error:nil];
         fprintf(stderr, "%s: loaded kernel_cpy_f32_f16: %p\n", __func__, (void *) ctx->pipeline_cpy_f32_f16);
+
+        ctx->function_cpy_f32_f32 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f32"];
+        ctx->pipeline_cpy_f32_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f32 error:nil];
+        fprintf(stderr, "%s: loaded kernel_cpy_f32_f32: %p\n", __func__, (void *) ctx->pipeline_cpy_f32_f32);
     }

     // MTLBuffer approach
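
Note: the pipelines above are created with error:nil, so a failed kernel load shows up only as a (nil) pointer in the log line. A minimal error-checked variant of the same load (a sketch; the NSError handling is not part of this commit):

    NSError * error = nil;
    ctx->function_cpy_f32_f32 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f32"];
    ctx->pipeline_cpy_f32_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f32
                                                                           error:&error];
    if (ctx->pipeline_cpy_f32_f32 == nil) {
        // Metal fills in the NSError when pipeline creation fails
        fprintf(stderr, "%s: failed to create kernel_cpy_f32_f32 pipeline: %s\n",
                __func__, [[error localizedDescription] UTF8String]);
    }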
@@ -669,6 +676,7 @@ int llama_mtl_eval(
                 {
                     switch (dstt) {
                         case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                        case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
                         default: GGML_ASSERT(false && "not implemented");
                     };
                 } break;
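
For context, a sketch of how such a copy pipeline is typically driven after the switch above (the buffer ids, offsets, and nth thread count are assumptions, not taken from this diff): one threadgroup per (i01, i02, i03) source row, so tgpig in the kernel indexes the row and the ntg.x threads stride over ne00.

    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
    // ... bind ne00..ne03, nb00..nb03, ne0..ne3, nb0..nb3 with setBytes ...
    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03)
            threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];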

@@ -340,3 +340,45 @@ kernel void kernel_cpy_f32_f16(
         dst_data[i00] = src[0];
     }
 }
+
+kernel void kernel_cpy_f32_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
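
The kernel flattens the source coordinates (i01, i02, i03) into a linear index n over the source shape (ne00..ne03), then unflattens n over the destination shape (ne0..ne3), so source and destination only have to agree on the total element count, not on shape or strides. A hypothetical graph snippet that would exercise it, assuming the usual ggml API (ctx0, n_embd, and N stand in for real values):

    // copy a 2-D F32 tensor into a flat, contiguous F32 tensor;
    // with both types F32, the Metal path now selects kernel_cpy_f32_f32
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd*N);
    struct ggml_tensor * c = ggml_cpy(ctx0, a, b);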

@@ -1350,11 +1350,6 @@ static bool llama_eval_internal(
                 il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
             ggml_set_name(V, "V");

-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(V, "mtl-check");
-            }
-
 #if 1
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
             ggml_set_name(KQV, "KQV");
@@ -1376,6 +1371,11 @@ static bool llama_eval_internal(
                 ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
             ggml_set_name(cur, "KQV_merged_contiguous");

+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(cur, "mtl-check");
+            }
+
             // projection (no bias)
             cur = ggml_mul_mat(ctx0,
                 model.layers[il].wo,
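
The "mtl-check" tag moves from V to KQV_merged_contiguous, the output of an f32 -> f32 ggml_cpy, so the CPU-vs-Metal comparison now runs through the new kernel. A sketch of how the host side might look the tagged tensor up, assuming ggml_graph_get_tensor() is available and gf is the built graph (the comparison itself is left out):

    struct ggml_tensor * t = ggml_graph_get_tensor(&gf, "mtl-check");
    GGML_ASSERT(t != NULL && "debug tensor not found");
    // ... compare t->data from the CPU run against the Metal output ...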