metal : fix ggml_get_rows to work with non-contiguous src1

This commit is contained in:
Georgi Gerganov 2023-12-10 09:38:21 +02:00
parent 0710b0f726
commit 016f9bb55a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 54 additions and 30 deletions

View file

@ -1584,11 +1584,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
const int64_t n = ggml_nelements(src1);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{