diff --git a/ggml-metal.metal b/ggml-metal.metal index 335b990d9..b16f2b7e0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3375,31 +3375,19 @@ kernel void kernel_concat( const int64_t i2 = tgpig.y; const int64_t i1 = tgpig.x; - device const char * src; - int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - if (dim > 0 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - src = src0; - o[dim] = 0; - } else { - src = src1; - o[dim] = dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03); - } + device const float * x; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (dim == 0) { - if (i0 < ne00) { - src = src0; - o[dim] = 0; - } else { - src = src1; - o[dim] = ne00; - } + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - device const float * x = (device const float *)(src + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); - device float * y = (device float *)(dst + (i3 )*nb3 + (i2 )*nb2 + (i1 )*nb1 + (i0 )*nb0); + device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } diff --git a/ggml.c b/ggml.c index 18c87d7d6..7f9961127 100644 --- a/ggml.c +++ b/ggml.c @@ -10985,9 +10985,10 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(dim >= 0 && dim < 4); - const char * src; - int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const float * x; // TODO: smarter multi-theading for (int i3 = 0; i3 < ne3; i3++) { @@ -10995,15 +10996,12 @@ static void ggml_compute_forward_concat_f32( for (int i1 = 0; i1 < ne1; i1++) { for (int i0 = 0; i0 < ne0; i0++) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - src = (const char *) src0->data; - o[dim] = 0; + x = (const float *) (src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); } else { - src = (const char *) src1->data; - o[dim] = src0->ne[dim]; + x = (const float *) (src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); } - const float * x = (const float *)( src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); - float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3); + float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3); *y = *x; }