metal : generalize concat kernel

This commit is contained in:
Georgi Gerganov 2024-05-27 14:33:30 +03:00
parent 0347657a3b
commit acdc075b60
No known key found for this signature in database
GPG key ID: BF970631944C16B7
3 changed files with 36 additions and 18 deletions

View file

@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute(
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
const int32_t dim = ((int32_t *) dst->op_params)[0];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
const int nth = MIN(1024, ne0);

View file

@ -3366,31 +3366,42 @@ kernel void kernel_concat(
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int32_t & dim,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig.z;
const int64_t i02 = tgpig.y;
const int64_t i01 = tgpig.x;
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
device const char * src;
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
int64_t o[4] = {0, 0, 0, 0};
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
src = src0;
o[dim] = 0;
} else {
src = src1;
o[dim] = ne00;
}
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (i02 < ne02) {
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
src0_ptr += ntg.x*nb00;
if (dim == 0) {
if (i0 < ne00) {
src = src0;
o[dim] = 0;
} else {
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
src1_ptr += ntg.x*nb10;
src = src1;
o[dim] = ne00;
}
dst_ptr += ntg.x*nb0;
}
device const float * x = (device const float *)(src + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
device float * y = (device float *)(dst + (i3 )*nb3 + (i2 )*nb2 + (i1 )*nb1 + (i0 )*nb0);
*y = *x;
}
}

8
ggml.c
View file

@ -10982,10 +10982,13 @@ static void ggml_compute_forward_concat_f32(
int32_t dim;
memcpy(&dim, dst->op_params, sizeof(int32_t));
GGML_ASSERT(dim >= 0 && dim < 4);
const char * src;
int64_t o[4] = {0, 0, 0, 0};
// TODO: smarter multi-threading
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ith; i2 < ne2; i2 += nth) {
for (int i1 = 0; i1 < ne1; i1++) {
@ -10997,9 +11000,10 @@ static void ggml_compute_forward_concat_f32(
src = (const char *) src1->data;
o[dim] = src0->ne[dim];
}
const float * x = (const float *)(src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13);
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
const float * x = (const float *)( src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13);
float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);
*y = *x;
}
}