metal : reduce the kernel launches for ggml_mul_mat_id

This commit is contained in:
Georgi Gerganov 2023-12-09 15:30:34 +02:00
parent 7e2006b0c0
commit 8c5b66eeaa
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 49 additions and 27 deletions

View file

@ -1495,6 +1495,9 @@ void ggml_metal_graph_compute(
const int idx = ((int32_t *) dst->op_params)[0];
// batch size
GGML_ASSERT(ne01 == ne11);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
@ -1515,19 +1518,25 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
[encoder setBytes:&idx length:sizeof(idx) atIndex:15];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
// TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j];
@ -1535,18 +1544,19 @@ void ggml_metal_graph_compute(
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
[encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake( (1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
//[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
//for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
// [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
// [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
// [encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
//}
}
} break;
case GGML_OP_GET_ROWS: