ggml : reimplement CPU and Metal

This commit is contained in:
Georgi Gerganov 2024-05-27 18:22:05 +03:00
parent 9d8f12ce4f
commit 94a0c7650d
No known key found for this signature in database
GPG key ID: BF970631944C16B7
2 changed files with 13 additions and 27 deletions

View file

@@ -3375,31 +3375,19 @@ kernel void kernel_concat(
const int64_t i2 = tgpig.y; const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x; const int64_t i1 = tgpig.x;
device const char * src;
int64_t o[4] = {0, 0, 0, 0}; int64_t o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
if (dim > 0 && i1 < ne01 && i2 < ne02 && i3 < ne03) { device const float * x;
src = src0;
o[dim] = 0;
} else {
src = src1;
o[dim] = dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03);
}
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (dim == 0) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
if (i0 < ne00) { x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
src = src0; } else {
o[dim] = 0; x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
} else {
src = src1;
o[dim] = ne00;
}
} }
device const float * x = (device const float *)(src + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
device float * y = (device float *)(dst + (i3 )*nb3 + (i2 )*nb2 + (i1 )*nb1 + (i0 )*nb0);
*y = *x; *y = *x;
} }

14
ggml.c
View file

@@ -10985,9 +10985,10 @@ static void ggml_compute_forward_concat_f32(
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);
const char * src;
int64_t o[4] = {0, 0, 0, 0}; int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];
const float * x;
// TODO: smarter multi-threading // TODO: smarter multi-threading
for (int i3 = 0; i3 < ne3; i3++) { for (int i3 = 0; i3 < ne3; i3++) {
@@ -10995,15 +10996,12 @@ static void ggml_compute_forward_concat_f32(
for (int i1 = 0; i1 < ne1; i1++) { for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) { for (int i0 = 0; i0 < ne0; i0++) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
src = (const char *) src0->data; x = (const float *) (src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
o[dim] = 0;
} else { } else {
src = (const char *) src1->data; x = (const float *) (src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
o[dim] = src0->ne[dim];
} }
const float * x = (const float *)( src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);
float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);
*y = *x; *y = *x;
} }