metal : add more general support for ggml_get_rows + tests

This commit is contained in:
Georgi Gerganov 2023-12-09 14:18:42 +02:00
parent 9064b1ca05
commit 2cbcba829f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 78 additions and 25 deletions

View file

@@ -805,8 +805,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
 case GGML_OP_NONE:
 case GGML_OP_RESHAPE:
 case GGML_OP_VIEW:
-case GGML_OP_TRANSPOSE:
 case GGML_OP_PERMUTE:
+case GGML_OP_TRANSPOSE:
+case GGML_OP_GET_ROWS:
 case GGML_OP_CONCAT:
 case GGML_OP_ADD:
 case GGML_OP_MUL:
@@ -828,7 +829,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
 case GGML_OP_MUL_MAT_ID:
 return true;
 case GGML_OP_DIAG_MASK_INF:
-case GGML_OP_GET_ROWS:
 {
 return op->ne[0] % 4 == 0;
 }
@@ -1568,16 +1568,18 @@ void ggml_metal_graph_compute(
 default: GGML_ASSERT(false && "not implemented");
 }
-[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
 [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
+[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
 const int64_t n = ggml_nelements(src1);
-[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
 } break;
 case GGML_OP_RMS_NORM:
 {