mtl : confirm f16 x f32 attention mul mat
This commit is contained in:
parent
948fcfde7e
commit
51efb59437
3 changed files with 106 additions and 71 deletions
|
@ -267,6 +267,7 @@ int llama_mtl_eval(
|
|||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_PERMUTE:
|
||||
{
|
||||
// noop
|
||||
} break;
|
||||
|
@ -344,66 +345,87 @@ int llama_mtl_eval(
|
|||
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
if (gf->nodes[i]->src0->type == GGML_TYPE_F32) {
|
||||
// for F32 x F32 we use MPS
|
||||
|
||||
if (encoder != nil) {
|
||||
[encoder endEncoding];
|
||||
encoder = nil;
|
||||
}
|
||||
|
||||
// use MPSMatrixMultiplication
|
||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
||||
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
|
||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
||||
|
||||
const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
|
||||
const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
|
||||
|
||||
const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
|
||||
const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
|
||||
|
||||
const int64_t ncols2 = gf->nodes[i]->ne[0];
|
||||
const int64_t nrows2 = gf->nodes[i]->ne[1];
|
||||
|
||||
GGML_ASSERT(ncols0 == ncols1);
|
||||
|
||||
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
|
||||
matrixDescriptorWithRows:nrows0 columns:ncols0 rowBytes:gf->nodes[i]->src0->nb[1] dataType:MPSDataTypeFloat32];
|
||||
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
|
||||
matrixDescriptorWithRows:nrows1 columns:ncols1 rowBytes:gf->nodes[i]->src1->nb[1] dataType:MPSDataTypeFloat32];
|
||||
MPSMatrixDescriptor * desc2 = [MPSMatrixDescriptor
|
||||
matrixDescriptorWithRows:nrows2 columns:ncols2 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
|
||||
|
||||
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0 descriptor:desc0];
|
||||
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1 descriptor:desc1];
|
||||
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst descriptor:desc2];
|
||||
|
||||
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] initWithDevice:ctx->device
|
||||
transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
|
||||
|
||||
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
|
||||
} else {
|
||||
// for Q4 x F32 we use custom kernel
|
||||
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
GGML_ASSERT(gf->nodes[i]->src0->ne[2] == 1);
|
||||
GGML_ASSERT(gf->nodes[i]->src1->ne[2] == 1);
|
||||
|
||||
{
|
||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
||||
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
|
||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
||||
|
||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
||||
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
|
||||
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
|
||||
|
||||
//const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
|
||||
//const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
||||
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
|
||||
|
||||
const int64_t ne10 = gf->nodes[i]->src1->ne[0];
|
||||
const int64_t ne11 = gf->nodes[i]->src1->ne[1];
|
||||
const int64_t ne12 = gf->nodes[i]->src1->ne[2];
|
||||
|
||||
//const uint64_t nb10 = gf->nodes[i]->src1->nb[0];
|
||||
//const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
|
||||
const uint64_t nb12 = gf->nodes[i]->src1->nb[2];
|
||||
|
||||
const int64_t ne0 = gf->nodes[i]->ne[0];
|
||||
const int64_t ne1 = gf->nodes[i]->ne[1];
|
||||
const int64_t ne2 = gf->nodes[i]->ne[2];
|
||||
|
||||
//const uint64_t nb0 = gf->nodes[i]->nb[0];
|
||||
//const uint64_t nb1 = gf->nodes[i]->nb[1];
|
||||
const uint64_t nb2 = gf->nodes[i]->nb[2];
|
||||
|
||||
const enum ggml_type src0t = gf->nodes[i]->src0->type;
|
||||
const enum ggml_type src1t = gf->nodes[i]->src1->type;
|
||||
const enum ggml_type dstt = gf->nodes[i]->type;
|
||||
|
||||
printf("mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
|
||||
printf("mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
|
||||
printf("mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
|
||||
printf("mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
|
||||
if (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) {
|
||||
if (encoder != nil) {
|
||||
[encoder endEncoding];
|
||||
encoder = nil;
|
||||
}
|
||||
|
||||
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
|
||||
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
|
||||
|
||||
// for F32 x F32 we use MPS
|
||||
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
|
||||
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:gf->nodes[i]->src0->nb[1] dataType:src0dt];
|
||||
|
||||
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
|
||||
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:gf->nodes[i]->src1->nb[1] dataType:src1dt];
|
||||
|
||||
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
|
||||
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
|
||||
|
||||
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
|
||||
initWithDevice:ctx->device transposeLeft:false transposeRight:true
|
||||
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
|
||||
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
||||
size_t offs_src0_cur = offs_src0 + i02*nb02;
|
||||
size_t offs_src1_cur = offs_src1 + i02*nb12;
|
||||
size_t offs_dst_cur = offs_dst + i02*nb2;
|
||||
|
||||
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
|
||||
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
|
||||
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
|
||||
|
||||
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
|
||||
}
|
||||
} else {
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
// for Q4 x F32 we use custom kernel
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
|
@ -416,9 +438,8 @@ int llama_mtl_eval(
|
|||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
|
|
8
ggml.c
8
ggml.c
|
@ -14613,7 +14613,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
|
|||
const int64_t * ne = tensor->ne;
|
||||
const size_t * nb = tensor->nb;
|
||||
|
||||
fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %16s\n",
|
||||
fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %32s\n",
|
||||
ggml_type_name(tensor->type),
|
||||
ggml_op_name (tensor->op),
|
||||
tensor->n_dims,
|
||||
|
@ -14627,7 +14627,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
|
|||
const int64_t * ne = tensor->ne;
|
||||
const size_t * nb = tensor->nb;
|
||||
|
||||
fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %16s\n",
|
||||
fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %32s\n",
|
||||
arg,
|
||||
ggml_type_name(tensor->type),
|
||||
ggml_op_name (tensor->op),
|
||||
|
@ -15067,6 +15067,10 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
|
|||
{
|
||||
tensor = ggml_transpose(*ctx_eval, args[0]);
|
||||
} break;
|
||||
case GGML_OP_PERMUTE:
|
||||
{
|
||||
tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
|
||||
|
|
24
llama.cpp
24
llama.cpp
|
@ -1289,16 +1289,22 @@ static bool llama_eval_internal(
|
|||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
||||
|
||||
struct ggml_tensor * t = ggml_cpy(ctx0, Kcur, k);
|
||||
// TODO: TMP !!!!
|
||||
if (il == 0) {
|
||||
ggml_set_name(t, "mtl-check");
|
||||
}
|
||||
//struct ggml_tensor * t = ggml_cpy(ctx0, Vcur, v);
|
||||
//// TODO: TMP !!!!
|
||||
//if (il == 0) {
|
||||
// ggml_set_name(t, "mtl-check");
|
||||
//}
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
//ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, t);
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||
//ggml_build_forward_expand(&gf, t);
|
||||
|
||||
// TODO: TMP !!!!!!!!!!
|
||||
if (il == 0) {
|
||||
ggml_build_forward_expand(&gf_export, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf_export, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
|
@ -1318,6 +1324,10 @@ static bool llama_eval_internal(
|
|||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
// TODO: TMP !!!!
|
||||
if (il == 0) {
|
||||
ggml_set_name(KQ, "mtl-check");
|
||||
}
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue