diff --git a/ggml-metal.m b/ggml-metal.m index ff9ae55aa..4ba498e87 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute( { id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + const int32_t dim = ((int32_t *) dst->op_params)[0]; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; const int nth = MIN(1024, ne0); diff --git a/ggml-metal.metal b/ggml-metal.metal index 174086b5b..342fa3707 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3366,31 +3366,42 @@ kernel void kernel_concat( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, + constant int32_t & dim, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + device const char * src; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + int64_t o[4] = {0, 0, 0, 0}; + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + src = src0; + o[dim] = 0; + } else { + src = src1; + o[dim] = ne00; + } for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; - } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; + if (dim == 0) { + if (i0 < ne00) { + src = src0; + o[dim] = 0; + } else { + src = src1; + o[dim] = ne00; + } } - dst_ptr += ntg.x*nb0; + + device const float * x = (device const float *)(src + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + device float * y = (device float *)(dst + (i3 )*nb3 + (i2 )*nb2 + (i1 )*nb1 + (i0 )*nb0); + + *y = *x; } } diff --git a/ggml.c b/ggml.c index c66a43740..c5fd8fcab 100644 --- a/ggml.c +++ b/ggml.c @@ -10982,10 +10982,13 @@ static void ggml_compute_forward_concat_f32( int32_t dim; memcpy(&dim, dst->op_params, sizeof(int32_t)); + GGML_ASSERT(dim >= 0 && dim < 4); + const char * src; int64_t o[4] = {0, 0, 0, 0}; + // TODO: smarter multi-theading for (int i3 = 0; i3 < ne3; i3++) { for (int i2 = ith; i2 < ne2; i2 += nth) { for (int i1 = 0; i1 < ne1; i1++) { @@ -10997,9 +11000,10 @@ static void ggml_compute_forward_concat_f32( src = (const char *) src1->data; o[dim] = src0->ne[dim]; } - const float * x = (const float *)(src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); - float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + const float * x = (const float *)( src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); + float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3); + *y = *x; } }