mtl : confirm f16 x f32 attention mul mat

Georgi Gerganov 2023-06-01 19:45:36 +03:00
parent 948fcfde7e
commit 51efb59437
3 changed files with 106 additions and 71 deletions


@@ -267,6 +267,7 @@ int llama_mtl_eval(
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
             case GGML_OP_TRANSPOSE:
+            case GGML_OP_PERMUTE:
                 {
                     // noop
                 } break;
@@ -344,81 +345,101 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_MUL_MAT:
-                if (gf->nodes[i]->src0->type == GGML_TYPE_F32) {
-                    // for F32 x F32 we use MPS
-                    if (encoder != nil) {
-                        [encoder endEncoding];
-                        encoder = nil;
-                    }
-
-                    // use MPSMatrixMultiplication
-                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
-                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
-                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
-
-                    const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
-                    const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
-
-                    const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
-                    const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
-
-                    const int64_t ncols2 = gf->nodes[i]->ne[0];
-                    const int64_t nrows2 = gf->nodes[i]->ne[1];
-
-                    GGML_ASSERT(ncols0 == ncols1);
-
-                    MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                        matrixDescriptorWithRows:nrows0 columns:ncols0 rowBytes:gf->nodes[i]->src0->nb[1] dataType:MPSDataTypeFloat32];
-                    MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                        matrixDescriptorWithRows:nrows1 columns:ncols1 rowBytes:gf->nodes[i]->src1->nb[1] dataType:MPSDataTypeFloat32];
-                    MPSMatrixDescriptor * desc2 = [MPSMatrixDescriptor
-                        matrixDescriptorWithRows:nrows2 columns:ncols2 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
-
-                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0 descriptor:desc0];
-                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1 descriptor:desc1];
-                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst  descriptor:desc2];
-
-                    MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] initWithDevice:ctx->device
-                        transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
-
-                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
-                } else {
-                    // for Q4 x F32 we use custom kernel
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    GGML_ASSERT(gf->nodes[i]->src0->ne[2] == 1);
-                    GGML_ASSERT(gf->nodes[i]->src1->ne[2] == 1);
-
+                {
                     id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                     id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
                     id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                     const int64_t ne00 = gf->nodes[i]->src0->ne[0];
                     const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+
+                  //const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
+                  //const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const uint64_t nb02 = gf->nodes[i]->src0->nb[2];

                     const int64_t ne10 = gf->nodes[i]->src1->ne[0];
                     const int64_t ne11 = gf->nodes[i]->src1->ne[1];
+                    const int64_t ne12 = gf->nodes[i]->src1->ne[2];
+
+                  //const uint64_t nb10 = gf->nodes[i]->src1->nb[0];
+                  //const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
+                    const uint64_t nb12 = gf->nodes[i]->src1->nb[2];

                     const int64_t ne0 = gf->nodes[i]->ne[0];
                     const int64_t ne1 = gf->nodes[i]->ne[1];
-
-                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:5];
-                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
-                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:7];
-                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:8];
-                    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                    printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                    const int64_t ne2 = gf->nodes[i]->ne[2];
+
+                  //const uint64_t nb0 = gf->nodes[i]->nb[0];
+                  //const uint64_t nb1 = gf->nodes[i]->nb[1];
+                    const uint64_t nb2 = gf->nodes[i]->nb[2];
+
+                    const enum ggml_type src0t = gf->nodes[i]->src0->type;
+                    const enum ggml_type src1t = gf->nodes[i]->src1->type;
+                    const enum ggml_type dstt  = gf->nodes[i]->type;
+
+                    printf("mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
+                    printf("mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
+                    printf("mul_mat: dst  - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt),  ne0,  ne1,  ne2);
+                    printf("mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
+
+                    GGML_ASSERT(ne00 == ne10);
+                    GGML_ASSERT(ne02 == ne12);
+
+                    if (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) {
+                        if (encoder != nil) {
+                            [encoder endEncoding];
+                            encoder = nil;
+                        }
+
+                        MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+                        MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+                        // for F32 x F32 we use MPS
+                        MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:gf->nodes[i]->src0->nb[1] dataType:src0dt];
+
+                        MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:gf->nodes[i]->src1->nb[1] dataType:src1dt];
+
+                        MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
+
+                        MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+                            initWithDevice:ctx->device transposeLeft:false transposeRight:true
+                                resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+                        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+                            size_t offs_src0_cur = offs_src0 + i02*nb02;
+                            size_t offs_src1_cur = offs_src1 + i02*nb12;
+                            size_t offs_dst_cur  = offs_dst  + i02*nb2;
+
+                            MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+                            MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+                            MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
+
+                            [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                        }
+                    } else {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        // for Q4 x F32 we use custom kernel
+                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:5];
+                        [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:7];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:8];
+                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                    }
                 } break;
             case GGML_OP_GET_ROWS:
                 {
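
Note on the MPS setup above: MPSMatrixMultiplication computes result = alpha * op(left) * op(right) + beta * result, while ggml_mul_mat(src0, src1) dots rows of src1 against rows of src0. Passing src1 as the left matrix and src0 as the transposed right matrix therefore yields dst = src1 * src0^T with shape ne11 x ne01, and supporting an f16 src0 only requires switching the descriptor's data type, which is what this commit does. A minimal, self-contained sketch of the same convention, with illustrative sizes and values that are not from the commit:

    // sketch.m: a hedged, standalone illustration (not part of the commit).
    // Build: clang sketch.m -framework Foundation -framework Metal -framework MetalPerformanceShaders
    #import <Foundation/Foundation.h>
    #import <Metal/Metal.h>
    #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
    #include <stdio.h>

    int main(void) {
        @autoreleasepool {
            id<MTLDevice>       device = MTLCreateSystemDefaultDevice();
            id<MTLCommandQueue> queue  = [device newCommandQueue];

            // illustrative sizes: src0 is M x K (ggml's ne01 x ne00),
            // src1 is N x K (ne11 x ne10), dst comes out N x M
            const NSUInteger M = 3, N = 2, K = 4;

            id<MTLBuffer> buf0 = [device newBufferWithLength:M*K*sizeof(float) options:MTLResourceStorageModeShared];
            id<MTLBuffer> buf1 = [device newBufferWithLength:N*K*sizeof(float) options:MTLResourceStorageModeShared];
            id<MTLBuffer> bufd = [device newBufferWithLength:N*M*sizeof(float) options:MTLResourceStorageModeShared];

            float * src0 = (float *) buf0.contents;
            float * src1 = (float *) buf1.contents;
            for (NSUInteger i = 0; i < M*K; ++i) src0[i] = (float) i;
            for (NSUInteger i = 0; i < N*K; ++i) src1[i] = 1.0f;

            // an f16 src0 would use MPSDataTypeFloat16 with rowBytes:K*2 instead
            MPSMatrixDescriptor * d0 = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:K rowBytes:K*sizeof(float) dataType:MPSDataTypeFloat32];
            MPSMatrixDescriptor * d1 = [MPSMatrixDescriptor matrixDescriptorWithRows:N columns:K rowBytes:K*sizeof(float) dataType:MPSDataTypeFloat32];
            MPSMatrixDescriptor * dd = [MPSMatrixDescriptor matrixDescriptorWithRows:N columns:M rowBytes:M*sizeof(float) dataType:MPSDataTypeFloat32];

            MPSMatrix * mat0 = [[MPSMatrix alloc] initWithBuffer:buf0 descriptor:d0];
            MPSMatrix * mat1 = [[MPSMatrix alloc] initWithBuffer:buf1 descriptor:d1];
            MPSMatrix * matd = [[MPSMatrix alloc] initWithBuffer:bufd descriptor:dd];

            // dst = src1 * src0^T: src1 goes in on the left, src0 transposed on the right
            MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
                initWithDevice:device transposeLeft:false transposeRight:true
                    resultRows:N resultColumns:M interiorColumns:K alpha:1.0 beta:0.0];

            id<MTLCommandBuffer> cb = [queue commandBuffer];
            [mul encodeToCommandBuffer:cb leftMatrix:mat1 rightMatrix:mat0 resultMatrix:matd];
            [cb commit];
            [cb waitUntilCompleted];

            // with all-ones src1, each dst row holds the row sums of src0,
            // which makes the transpose convention easy to eyeball
            const float * dst = (const float *) bufd.contents;
            for (NSUInteger r = 0; r < N; ++r) {
                for (NSUInteger c = 0; c < M; ++c) printf("%8.1f", dst[r*M + c]);
                printf("\n");
            }
        }
        return 0;
    }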

ggml.c

@@ -14613,7 +14613,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fout
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %16s\n",
+    fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
             tensor->n_dims,
@@ -14627,7 +14627,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %16s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
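
The only functional change in these two hunks is the final column width: the tensor-name field grows from %16s to %32s. A printf field width is a minimum, not a truncation, so long names were already printed in full; they simply overflowed the column and misaligned the exported table. A quick check of that behavior:

    #include <stdio.h>

    int main(void) {
        printf("[%16s]\n", "KQ");                        // padded to 16 chars
        printf("[%16s]\n", "a-rather-long-tensor-name"); // overflows the column
        printf("[%32s]\n", "a-rather-long-tensor-name"); // stays aligned at 32
        return 0;
    }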
@@ -15067,6 +15067,10 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                 {
                     tensor = ggml_transpose(*ctx_eval, args[0]);
                 } break;
+            case GGML_OP_PERMUTE:
+                {
+                    tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+                } break;
             default:
                 {
                     tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
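
GGML_OP_PERMUTE can be rebuilt on import as a plain view with the stored shape because ggml_permute moves no data; it only rearranges the ne/nb metadata, which is also why the Metal eval above now treats it as a noop next to reshape, view, and transpose. Note that the ggml_view_4d call passes zero strides and offset, so only the shape is restored here, not the permuted strides; this relies on how the imported graph is then consumed. A small plain-C sketch against ggml.h of this period, showing that a permuted tensor shares its source's data:

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 3, 2);
        struct ggml_tensor * p = ggml_permute(ctx, a, 1, 0, 2, 3); // swap dims 0 and 1

        // same data pointer; only the shape/stride metadata differs
        printf("a: data=%p ne=[%lld, %lld, %lld]\n", a->data, a->ne[0], a->ne[1], a->ne[2]);
        printf("p: data=%p ne=[%lld, %lld, %lld]\n", p->data, p->ne[0], p->ne[1], p->ne[2]);

        ggml_free(ctx);
        return 0;
    }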

llama.cpp

@@ -1289,16 +1289,22 @@ static bool llama_eval_internal(
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));

-            struct ggml_tensor * t = ggml_cpy(ctx0, Kcur, k);
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(t, "mtl-check");
-            }
+            //struct ggml_tensor * t = ggml_cpy(ctx0, Vcur, v);
+            //// TODO: TMP !!!!
+            //if (il == 0) {
+            //    ggml_set_name(t, "mtl-check");
+            //}

             // important: storing RoPE-ed version of K in the KV cache!
-            //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-            ggml_build_forward_expand(&gf, t);
+            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            //ggml_build_forward_expand(&gf, t);
+
+            // TODO: TMP !!!!!!!!!!
+            if (il == 0) {
+                ggml_build_forward_expand(&gf_export, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf_export, ggml_cpy(ctx0, Vcur, v));
+            }
         }

         struct ggml_tensor * Q =
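
This hunk stops naming the cache-copy tensor for the Metal check (the commented block) and instead also registers layer 0's KV-cache writes in gf_export, presumably the graph handed to ggml_graph_export later in this function. ggml_build_forward_expand adds a tensor together with its whole dependency chain to one specific graph, so the same nodes can be placed in several cgraphs. A toy sketch, assuming this era's ggml API with by-value cgraph structs:

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_f32(ctx, 2.0f);
        struct ggml_tensor * y = ggml_sqr(ctx, x); // y = x^2
        struct ggml_tensor * z = ggml_scale(ctx, y, ggml_new_f32(ctx, 0.5f));

        struct ggml_cgraph gf        = {0}; // main graph
        struct ggml_cgraph gf_export = {0}; // side graph, as with gf_export above

        ggml_build_forward_expand(&gf,        z); // pulls in x -> y -> z
        ggml_build_forward_expand(&gf_export, y); // pulls in only x -> y

        printf("gf: %d nodes, gf_export: %d nodes\n", gf.n_nodes, gf_export.n_nodes);

        ggml_free(ctx);
        return 0;
    }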
@@ -1318,6 +1324,10 @@ static bool llama_eval_internal(
         // K * Q
         struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
         ggml_set_name(KQ, "KQ");
+        // TODO: TMP !!!!
+        if (il == 0) {
+            ggml_set_name(KQ, "mtl-check");
+        }

         // KQ_scaled = KQ / sqrt(n_embd/n_head)
         struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
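
Renaming KQ to "mtl-check" is what the commit title refers to: with the default f16 KV cache, K is an f16 tensor while Q is f32, so KQ = ggml_mul_mat(K, Q) is exactly the f16 x f32 matrix product the new MPS path handles, batched over ne02 = n_head head slices. A hedged sketch of how a test harness can fetch the named tensor after computing the graph, assuming ggml_graph_get_tensor from this era's ggml.h and using a toy graph in place of the real KQ node:

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_f32(ctx, 2.0f);
        struct ggml_tensor * b = ggml_new_f32(ctx, 3.0f);
        struct ggml_tensor * c = ggml_mul(ctx, a, b);
        ggml_set_name(c, "mtl-check");

        struct ggml_cgraph gf = ggml_build_forward(c);
        ggml_graph_compute(ctx, &gf);

        // the harness would compare this value against the buffer read back from Metal
        struct ggml_tensor * chk = ggml_graph_get_tensor(&gf, "mtl-check");
        printf("mtl-check = %f\n", *(float *) chk->data);

        ggml_free(ctx);
        return 0;
    }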