ggml : reimplement CPU and Metal
This commit is contained in:
parent
9d8f12ce4f
commit
94a0c7650d
2 changed files with 13 additions and 27 deletions
|
@ -3375,31 +3375,19 @@ kernel void kernel_concat(
|
||||||
const int64_t i2 = tgpig.y;
|
const int64_t i2 = tgpig.y;
|
||||||
const int64_t i1 = tgpig.x;
|
const int64_t i1 = tgpig.x;
|
||||||
|
|
||||||
device const char * src;
|
|
||||||
|
|
||||||
int64_t o[4] = {0, 0, 0, 0};
|
int64_t o[4] = {0, 0, 0, 0};
|
||||||
|
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
||||||
|
|
||||||
if (dim > 0 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
device const float * x;
|
||||||
src = src0;
|
|
||||||
o[dim] = 0;
|
|
||||||
} else {
|
|
||||||
src = src1;
|
|
||||||
o[dim] = dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||||
if (dim == 0) {
|
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||||
if (i0 < ne00) {
|
x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
|
||||||
src = src0;
|
} else {
|
||||||
o[dim] = 0;
|
x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
|
||||||
} else {
|
|
||||||
src = src1;
|
|
||||||
o[dim] = ne00;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
device const float * x = (device const float *)(src + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
|
device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
device float * y = (device float *)(dst + (i3 )*nb3 + (i2 )*nb2 + (i1 )*nb1 + (i0 )*nb0);
|
|
||||||
|
|
||||||
*y = *x;
|
*y = *x;
|
||||||
}
|
}
|
||||||
|
|
14
ggml.c
14
ggml.c
|
@ -10985,9 +10985,10 @@ static void ggml_compute_forward_concat_f32(
|
||||||
|
|
||||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
|
|
||||||
const char * src;
|
|
||||||
|
|
||||||
int64_t o[4] = {0, 0, 0, 0};
|
int64_t o[4] = {0, 0, 0, 0};
|
||||||
|
o[dim] = src0->ne[dim];
|
||||||
|
|
||||||
|
const float * x;
|
||||||
|
|
||||||
// TODO: smarter multi-theading
|
// TODO: smarter multi-theading
|
||||||
for (int i3 = 0; i3 < ne3; i3++) {
|
for (int i3 = 0; i3 < ne3; i3++) {
|
||||||
|
@ -10995,15 +10996,12 @@ static void ggml_compute_forward_concat_f32(
|
||||||
for (int i1 = 0; i1 < ne1; i1++) {
|
for (int i1 = 0; i1 < ne1; i1++) {
|
||||||
for (int i0 = 0; i0 < ne0; i0++) {
|
for (int i0 = 0; i0 < ne0; i0++) {
|
||||||
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||||
src = (const char *) src0->data;
|
x = (const float *) (src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
|
||||||
o[dim] = 0;
|
|
||||||
} else {
|
} else {
|
||||||
src = (const char *) src1->data;
|
x = (const float *) (src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
|
||||||
o[dim] = src0->ne[dim];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const float * x = (const float *)( src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13);
|
float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);
|
||||||
float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);
|
|
||||||
|
|
||||||
*y = *x;
|
*y = *x;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue