ggml : generalize GGML_OP_CONCAT (#7563)
* ggml : generalize GGML_OP_CONCAT (WIP) ggml-ci
* tests : add dim != 2 tests
* metal : generalize concat kernel
* tests : naming
* cuda : generalize concat kernel ggml-ci
* sycl : add warning and assert
* ggml : fix op params handling
* metal : bugfix kernel ggml-ci
* ggml : reimplement CPU and Metal
* cuda : add asserts ggml-ci
* ggml : fix ptrs ggml-ci
This commit is contained in:
parent
9335b969e8
commit
0548a4187f
7 changed files with 167 additions and 56 deletions
61
ggml.c
61
ggml.c
|
@ -4882,10 +4882,21 @@ struct ggml_tensor * ggml_repeat_back(
|
|||
// ggml_concat
|
||||
|
||||
struct ggml_tensor * ggml_concat(
|
||||
struct ggml_context* ctx,
|
||||
struct ggml_tensor* a,
|
||||
struct ggml_tensor* b) {
|
||||
GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int dim) {
|
||||
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
||||
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
|
||||
if (d == dim) {
|
||||
ne[d] = a->ne[d] + b->ne[d];
|
||||
continue;
|
||||
}
|
||||
GGML_ASSERT(a->ne[d] == b->ne[d]);
|
||||
ne[d] = a->ne[d];
|
||||
}
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
|
@ -4893,7 +4904,9 @@ struct ggml_tensor * ggml_concat(
|
|||
is_node = true;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, dim);
|
||||
|
||||
result->op = GGML_OP_CONCAT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
|
@ -5013,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu(
|
|||
}
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
|
||||
|
||||
result->op = GGML_OP_LEAKY_RELU;
|
||||
|
@ -10967,26 +10981,29 @@ static void ggml_compute_forward_concat_f32(
|
|||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||
|
||||
int64_t o[4] = {0, 0, 0, 0};
|
||||
o[dim] = src0->ne[dim];
|
||||
|
||||
const float * x;
|
||||
|
||||
// TODO: smarter multi-threading
|
||||
for (int i3 = 0; i3 < ne3; i3++) {
|
||||
for (int i2 = ith; i2 < ne2; i2 += nth) {
|
||||
if (i2 < ne02) { // src0
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
|
||||
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
*y = *x;
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||
x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
|
||||
} else {
|
||||
x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
|
||||
}
|
||||
}
|
||||
} // src1
|
||||
else {
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
|
||||
|
||||
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
*y = *x;
|
||||
}
|
||||
float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10994,7 +11011,7 @@ static void ggml_compute_forward_concat_f32(
|
|||
}
|
||||
|
||||
static void ggml_compute_forward_concat(
|
||||
const struct ggml_compute_params* params,
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor* dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue