mtl : confirm f16 x f32 attention mul mat

commit 51efb59437
parent 948fcfde7e

3 changed files with 106 additions and 71 deletions
@@ -267,6 +267,7 @@ int llama_mtl_eval(
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
             case GGML_OP_TRANSPOSE:
+            case GGML_OP_PERMUTE:
                 {
                     // noop
                 } break;
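The new GGML_OP_PERMUTE case joins RESHAPE, VIEW and TRANSPOSE as a no-op because in ggml a permute only reorders the ne (dims) and nb (strides) metadata; no bytes move, so there is nothing for the Metal encoder to dispatch. A minimal C sketch of that metadata shuffle (an illustration, not code from this commit):

    #include <stdint.h>
    #include <stddef.h>

    // Shape/stride permutation as ggml_permute performs it: destination
    // axis[i] receives dimension i of the source. The data pointer is
    // untouched, which is why the backend can treat the op as a noop.
    static void permute_meta(int64_t ne[4], size_t nb[4],
                             int axis0, int axis1, int axis2, int axis3) {
        const int axis[4] = { axis0, axis1, axis2, axis3 };
        int64_t ne2[4];
        size_t  nb2[4];
        for (int i = 0; i < 4; ++i) {
            ne2[axis[i]] = ne[i];
            nb2[axis[i]] = nb[i];
        }
        for (int i = 0; i < 4; ++i) {
            ne[i] = ne2[i];
            nb[i] = nb2[i];
        }
    }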
@@ -344,66 +345,87 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_MUL_MAT:
-                if (gf->nodes[i]->src0->type == GGML_TYPE_F32) {
-                    // for F32 x F32 we use MPS
-                    if (encoder != nil) {
-                        [encoder endEncoding];
-                        encoder = nil;
-                    }
-
-                    // use MPSMatrixMultiplication
-                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
-                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
-                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
-
-                    const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
-                    const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
-
-                    const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
-                    const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
-
-                    const int64_t ncols2 = gf->nodes[i]->ne[0];
-                    const int64_t nrows2 = gf->nodes[i]->ne[1];
-
-                    GGML_ASSERT(ncols0 == ncols1);
-
-                    MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                        matrixDescriptorWithRows:nrows0 columns:ncols0 rowBytes:gf->nodes[i]->src0->nb[1] dataType:MPSDataTypeFloat32];
-                    MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                        matrixDescriptorWithRows:nrows1 columns:ncols1 rowBytes:gf->nodes[i]->src1->nb[1] dataType:MPSDataTypeFloat32];
-                    MPSMatrixDescriptor * desc2 = [MPSMatrixDescriptor
-                        matrixDescriptorWithRows:nrows2 columns:ncols2 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
-
-                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0 descriptor:desc0];
-                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1 descriptor:desc1];
-                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst  descriptor:desc2];
-
-                    MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] initWithDevice:ctx->device
-                        transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
-
-                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
-                } else {
-                    // for Q4 x F32 we use custom kernel
-
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    GGML_ASSERT(gf->nodes[i]->src0->ne[2] == 1);
-                    GGML_ASSERT(gf->nodes[i]->src1->ne[2] == 1);
-
+                {
                     id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                     id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
                     id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                     const int64_t ne00 = gf->nodes[i]->src0->ne[0];
                     const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+
+                    //const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
+                    //const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const uint64_t nb02 = gf->nodes[i]->src0->nb[2];

                     const int64_t ne10 = gf->nodes[i]->src1->ne[0];
                     const int64_t ne11 = gf->nodes[i]->src1->ne[1];
+                    const int64_t ne12 = gf->nodes[i]->src1->ne[2];
+
+                    //const uint64_t nb10 = gf->nodes[i]->src1->nb[0];
+                    //const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
+                    const uint64_t nb12 = gf->nodes[i]->src1->nb[2];

                     const int64_t ne0 = gf->nodes[i]->ne[0];
                     const int64_t ne1 = gf->nodes[i]->ne[1];
+                    const int64_t ne2 = gf->nodes[i]->ne[2];
+
+                    //const uint64_t nb0 = gf->nodes[i]->nb[0];
+                    //const uint64_t nb1 = gf->nodes[i]->nb[1];
+                    const uint64_t nb2 = gf->nodes[i]->nb[2];
+
+                    const enum ggml_type src0t = gf->nodes[i]->src0->type;
+                    const enum ggml_type src1t = gf->nodes[i]->src1->type;
+                    const enum ggml_type dstt  = gf->nodes[i]->type;
+
+                    printf("mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
+                    printf("mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
+                    printf("mul_mat: dst  - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt),  ne0,  ne1,  ne2);
+                    printf("mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
+
+                    GGML_ASSERT(ne00 == ne10);
+                    GGML_ASSERT(ne02 == ne12);
+
+                    if (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) {
+                        if (encoder != nil) {
+                            [encoder endEncoding];
+                            encoder = nil;
+                        }
+
+                        MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+                        MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+                        // for F32 x F32 we use MPS
+                        MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:gf->nodes[i]->src0->nb[1] dataType:src0dt];
+
+                        MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:gf->nodes[i]->src1->nb[1] dataType:src1dt];
+
+                        MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
+                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
+
+                        MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+                            initWithDevice:ctx->device transposeLeft:false transposeRight:true
+                                resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+                        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+                            size_t offs_src0_cur = offs_src0 + i02*nb02;
+                            size_t offs_src1_cur = offs_src1 + i02*nb12;
+                            size_t offs_dst_cur  = offs_dst  + i02*nb2;
+
+                            MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+                            MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+                            MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
+
+                            [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                        }
+                    } else {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        // for Q4 x F32 we use custom kernel
                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -416,9 +438,8 @@ int llama_mtl_eval(
                         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8];
                         [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-                        printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1);
-
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                    }
                 } break;
             case GGML_OP_GET_ROWS:
                 {
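For reference, the MPS branch above encodes, for every slice i02, the product dst = src1 * src0^T: leftMatrix is src1, rightMatrix is src0 with transposeRight:true, and the result has resultRows:ne11, resultColumns:ne01, interiorColumns:ne00. A plain-C sketch of the same arithmetic (illustration only; it assumes both operands are contiguous F32, whereas the commit also admits an F16 src0 via src0dt):

    #include <stdint.h>

    // Reference for the batched MPS multiply: per slice, every row of src1
    // (ne10 == ne00 wide) is dotted against every row of src0 (ne00 wide).
    static void mul_mat_ref(const float * src0, const float * src1, float * dst,
                            int64_t ne00, int64_t ne01, int64_t ne11, int64_t ne02) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            const float * s0 = src0 + i02*ne00*ne01;  // slice of src0: ne01 rows
            const float * s1 = src1 + i02*ne00*ne11;  // slice of src1: ne11 rows
            float       * d  = dst  + i02*ne01*ne11;  // result slice: ne11 x ne01
            for (int64_t r = 0; r < ne11; ++r) {          // resultRows    = ne11
                for (int64_t c = 0; c < ne01; ++c) {      // resultColumns = ne01
                    float sum = 0.0f;
                    for (int64_t k = 0; k < ne00; ++k) {  // interiorColumns = ne00
                        sum += s1[r*ne00 + k] * s0[c*ne00 + k];
                    }
                    d[r*ne01 + c] = sum;
                }
            }
        }
    }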
ggml.c: 8 changed lines
@@ -14613,7 +14613,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %16s\n",
+    fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
             tensor->n_dims,

@@ -14627,7 +14627,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %16s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
@@ -15067,6 +15067,10 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                         {
                             tensor = ggml_transpose(*ctx_eval, args[0]);
                         } break;
+                    case GGML_OP_PERMUTE:
+                        {
+                            tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+                        } break;
                     default:
                         {
                             tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
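Note that the new import case rebuilds a PERMUTE result as a plain 4-D view: ggml_view_4d receives the recorded dimensions ne[0..3] but zeros for the nb1/nb2/nb3 stride arguments and the offset, so only the permuted logical shape is reconstructed here, not the permuted strides. A sketch of the pattern (the helper name is hypothetical; only the ggml_view_4d call is taken from the diff):

    #include "ggml.h"

    // Hypothetical helper mirroring the new GGML_OP_PERMUTE import path:
    // rebuild the shape of the permuted tensor as a view over args[0].
    // Stride and offset arguments are zero, exactly as in the commit.
    static struct ggml_tensor * import_permute(struct ggml_context * ctx_eval,
                                               struct ggml_tensor  * arg0,
                                               const int64_t         ne[4]) {
        return ggml_view_4d(ctx_eval, arg0, ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
    }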
llama.cpp: 24 changed lines
@@ -1289,16 +1289,22 @@ static bool llama_eval_internal(
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));

-            struct ggml_tensor * t = ggml_cpy(ctx0, Kcur, k);
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(t, "mtl-check");
-            }
+            //struct ggml_tensor * t = ggml_cpy(ctx0, Vcur, v);
+            //// TODO: TMP !!!!
+            //if (il == 0) {
+            //    ggml_set_name(t, "mtl-check");
+            //}

             // important: storing RoPE-ed version of K in the KV cache!
-            //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-            ggml_build_forward_expand(&gf, t);
+            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            //ggml_build_forward_expand(&gf, t);
+
+            // TODO: TMP !!!!!!!!!!
+            if (il == 0) {
+                ggml_build_forward_expand(&gf_export, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf_export, ggml_cpy(ctx0, Vcur, v));
+            }
         }

         struct ggml_tensor * Q =
@@ -1318,6 +1324,10 @@ static bool llama_eval_internal(
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
             ggml_set_name(KQ, "KQ");
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(KQ, "mtl-check");
+            }

             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
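The net effect in llama.cpp is that the temporary "mtl-check" tensor moves from the Kcur KV-cache copy to KQ on layer 0, so the exported graph now exercises exactly what the commit title claims: the K (F16) x Q (F32) attention matrix product on the new MPS path. A minimal sketch of the kind of comparison such a check enables (not code from this commit; the two buffers would come from the CPU and Metal evaluations of the exported graph):

    #include <math.h>
    #include <stdio.h>

    // Compare a tensor computed on the CPU against the Metal result.
    static float max_abs_diff(const float * cpu, const float * mtl, int n) {
        float max = 0.0f;
        for (int i = 0; i < n; ++i) {
            const float d = fabsf(cpu[i] - mtl[i]);
            if (d > max) {
                max = d;
            }
        }
        return max;
    }

    int main(void) {
        const float cpu[4] = { 1.0f, 2.0f, 3.0f, 4.0f };  // stand-in CPU output
        const float mtl[4] = { 1.0f, 2.0f, 3.0f, 4.0f };  // stand-in Metal output
        printf("max abs diff: %g\n", max_abs_diff(cpu, mtl, 4));
        return 0;
    }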