metal : add more general support for ggml_get_rows + tests
This commit is contained in:
parent
9064b1ca05
commit
2cbcba829f
4 changed files with 78 additions and 25 deletions
16
ggml-metal.m
16
ggml-metal.m
|
@ -805,8 +805,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_MUL:
|
||||
|
@ -828,7 +829,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|||
case GGML_OP_MUL_MAT_ID:
|
||||
return true;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[0] % 4 == 0;
|
||||
}
|
||||
|
@ -1568,16 +1568,18 @@ void ggml_metal_graph_compute(
|
|||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
|
||||
|
||||
const int64_t n = ggml_nelements(src1);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue