metal : generalize concat kernel
This commit is contained in:
parent
0347657a3b
commit
acdc075b60
3 changed files with 36 additions and 18 deletions
|
@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
||||||
|
|
||||||
|
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
||||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
||||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
||||||
|
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
||||||
|
|
||||||
const int nth = MIN(1024, ne0);
|
const int nth = MIN(1024, ne0);
|
||||||
|
|
||||||
|
|
|
@ -3366,31 +3366,42 @@ kernel void kernel_concat(
|
||||||
constant uint64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint64_t & nb2,
|
constant uint64_t & nb2,
|
||||||
constant uint64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
|
constant int32_t & dim,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
const int64_t i03 = tgpig.z;
|
const int64_t i3 = tgpig.z;
|
||||||
const int64_t i02 = tgpig.y;
|
const int64_t i2 = tgpig.y;
|
||||||
const int64_t i01 = tgpig.x;
|
const int64_t i1 = tgpig.x;
|
||||||
|
|
||||||
const int64_t i13 = i03 % ne13;
|
device const char * src;
|
||||||
const int64_t i12 = i02 % ne12;
|
|
||||||
const int64_t i11 = i01 % ne11;
|
|
||||||
|
|
||||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
int64_t o[4] = {0, 0, 0, 0};
|
||||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
|
||||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||||
|
src = src0;
|
||||||
|
o[dim] = 0;
|
||||||
|
} else {
|
||||||
|
src = src1;
|
||||||
|
o[dim] = ne00;
|
||||||
|
}
|
||||||
|
|
||||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||||
if (i02 < ne02) {
|
if (dim == 0) {
|
||||||
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
|
if (i0 < ne00) {
|
||||||
src0_ptr += ntg.x*nb00;
|
src = src0;
|
||||||
|
o[dim] = 0;
|
||||||
} else {
|
} else {
|
||||||
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
|
src = src1;
|
||||||
src1_ptr += ntg.x*nb10;
|
o[dim] = ne00;
|
||||||
}
|
}
|
||||||
dst_ptr += ntg.x*nb0;
|
}
|
||||||
|
|
||||||
|
device const float * x = (device const float *)(src + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
|
||||||
|
device float * y = (device float *)(dst + (i3 )*nb3 + (i2 )*nb2 + (i1 )*nb1 + (i0 )*nb0);
|
||||||
|
|
||||||
|
*y = *x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
8
ggml.c
8
ggml.c
|
@ -10982,10 +10982,13 @@ static void ggml_compute_forward_concat_f32(
|
||||||
int32_t dim;
|
int32_t dim;
|
||||||
memcpy(&dim, dst->op_params, sizeof(int32_t));
|
memcpy(&dim, dst->op_params, sizeof(int32_t));
|
||||||
|
|
||||||
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
|
|
||||||
const char * src;
|
const char * src;
|
||||||
|
|
||||||
int64_t o[4] = {0, 0, 0, 0};
|
int64_t o[4] = {0, 0, 0, 0};
|
||||||
|
|
||||||
|
// TODO: smarter multi-theading
|
||||||
for (int i3 = 0; i3 < ne3; i3++) {
|
for (int i3 = 0; i3 < ne3; i3++) {
|
||||||
for (int i2 = ith; i2 < ne2; i2 += nth) {
|
for (int i2 = ith; i2 < ne2; i2 += nth) {
|
||||||
for (int i1 = 0; i1 < ne1; i1++) {
|
for (int i1 = 0; i1 < ne1; i1++) {
|
||||||
|
@ -10997,9 +11000,10 @@ static void ggml_compute_forward_concat_f32(
|
||||||
src = (const char *) src1->data;
|
src = (const char *) src1->data;
|
||||||
o[dim] = src0->ne[dim];
|
o[dim] = src0->ne[dim];
|
||||||
}
|
}
|
||||||
const float * x = (const float *)(src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13);
|
|
||||||
|
|
||||||
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
const float * x = (const float *)( src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13);
|
||||||
|
float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);
|
||||||
|
|
||||||
*y = *x;
|
*y = *x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue