ggml : generalize GGML_OP_CONCAT (#7563)
* ggml : generalize GGML_OP_CONCAT (WIP) ggml-ci * tests : add dim != 2 tests * metal : generalize concat kernel * tests : naming * cuda : generalize concat kernel ggml-ci * sycl : add warning and assert * ggml : fix op params handling * metal : bugfix kernel ggml-ci * ggml : reimplement CPU and Metal * cuda : add asserts ggml-ci * ggml : fix ptrs ggml-ci
This commit is contained in:
parent
9335b969e8
commit
0548a4187f
7 changed files with 167 additions and 56 deletions
|
@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
||||
|
||||
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
|
@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
||||
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue