ggml : generalize GGML_OP_CONCAT (WIP)

ggml-ci
Georgi Gerganov 2024-05-27 14:14:56 +03:00
parent 1d8fca72ae
commit 0b52245bdb
3 changed files with 40 additions and 25 deletions

ggml.c (52 changes)

@@ -4882,10 +4882,19 @@ struct ggml_tensor * ggml_repeat_back(
 
 // ggml_concat
 
 struct ggml_tensor * ggml_concat(
-    struct ggml_context* ctx,
-    struct ggml_tensor* a,
-    struct ggml_tensor* b) {
-    GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    int dim) {
+    int64_t ne[GGML_MAX_DIMS];
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d == dim) {
+            ne[d] = a->ne[d] + b->ne[d];
+            continue;
+        }
+        GGML_ASSERT(a->ne[d] == b->ne[d]);
+        ne[d] = a->ne[d];
+    }
 
     bool is_node = false;
@@ -4893,7 +4902,9 @@ struct ggml_tensor * ggml_concat(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params(result, &dim, sizeof(dim));
 
     result->op = GGML_OP_CONCAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
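
For reference, here is a standalone sketch (not part of the commit; plain C, with a hypothetical helper name) of the shape rule the generalized constructor implements: the two sizes add up along the concat dimension, and every other dimension must match.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define N_DIMS 4 // stands in for GGML_MAX_DIMS in this sketch

// hypothetical helper, not a ggml API: compute the shape of concat(a, b) along dim,
// mirroring the loop added to ggml_concat above
static void concat_shape(const int64_t a[N_DIMS], const int64_t b[N_DIMS], int dim, int64_t out[N_DIMS]) {
    for (int d = 0; d < N_DIMS; ++d) {
        if (d == dim) {
            out[d] = a[d] + b[d]; // sizes add up along the concat dimension
        } else {
            assert(a[d] == b[d]); // all other dimensions must agree
            out[d] = a[d];
        }
    }
}

int main(void) {
    const int64_t a[N_DIMS] = {10, 5, 4, 3};
    const int64_t b[N_DIMS] = {10, 5, 7, 3};
    int64_t out[N_DIMS];
    concat_shape(a, b, /*dim =*/ 2, out);
    printf("%lld %lld %lld %lld\n",
           (long long) out[0], (long long) out[1], (long long) out[2], (long long) out[3]); // 10 5 11 3
    return 0;
}
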
@@ -5013,6 +5024,7 @@ struct ggml_tensor * ggml_leaky_relu(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
     ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
     result->op = GGML_OP_LEAKY_RELU;
@@ -10967,34 +10979,36 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
+    int32_t dim;
+    memcpy(&dim, dst->op_params, sizeof(int32_t));
+
+    const char * src;
+
+    int64_t o[4] = {0, 0, 0, 0};
+
     for (int i3 = 0; i3 < ne3; i3++) {
         for (int i2 = ith; i2 < ne2; i2 += nth) {
-            if (i2 < ne02) { // src0
                 for (int i1 = 0; i1 < ne1; i1++) {
                     for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
+                        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                            src = (const char *) src0->data;
+                            o[dim] = 0;
+                        } else {
+                            src = (const char *) src1->data;
+                            o[dim] = src0->ne[dim];
+                        }
+
+                        const float * x = (const float *)(src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13);
 
                         float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
                         *y = *x;
                     }
                 }
-            } // src1
-            else {
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
-                    }
-                }
-            }
         }
     }
 }
 
 static void ggml_compute_forward_concat(
-    const struct ggml_compute_params* params,
+    const struct ggml_compute_params * params,
     struct ggml_tensor* dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
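
In words: for every destination index, the generalized kernel reads from src0 whenever the index lies inside src0's extents in all four dimensions, and otherwise reads from src1 with the index shifted back by src0->ne[dim] along the concat dimension. A standalone illustration of that rule (not ggml code; names are made up, specialized to two contiguous 2-D float arrays):

#include <stdint.h>
#include <stdio.h>

// sketch only: concatenate two contiguous 2-D float arrays along dim,
// using the same selection rule as the generalized kernel above --
// an index inside src0's extents reads from src0, everything else reads
// from src1 with the index shifted back by src0's size along dim
static void concat_2d(const float * src0, int64_t ne00, int64_t ne01,
                      const float * src1, int64_t ne10, int64_t ne11,
                      int dim, float * dst) {
    const int64_t ne0 = dim == 0 ? ne00 + ne10 : ne00; // destination extents
    const int64_t ne1 = dim == 1 ? ne01 + ne11 : ne01;

    for (int64_t i1 = 0; i1 < ne1; i1++) {
        for (int64_t i0 = 0; i0 < ne0; i0++) {
            float v;
            if (i0 < ne00 && i1 < ne01) {
                v = src0[i1*ne00 + i0];                 // inside src0
            } else {
                const int64_t j0 = dim == 0 ? i0 - ne00 : i0; // shift back along dim
                const int64_t j1 = dim == 1 ? i1 - ne01 : i1;
                v = src1[j1*ne10 + j0];
            }
            dst[i1*ne0 + i0] = v;
        }
    }
}

int main(void) {
    const float a[2*2] = {1, 2, 3, 4};        // ne00 = 2, ne01 = 2 (row-major)
    const float b[3*2] = {5, 6, 7, 8, 9, 10}; // ne10 = 2, ne11 = 3
    float c[5*2];
    concat_2d(a, 2, 2, b, 2, 3, /*dim =*/ 1, c);
    for (int i = 0; i < 5*2; i++) printf("%g ", c[i]); // 1 2 3 4 5 6 7 8 9 10
    printf("\n");
    return 0;
}
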

ggml.h (5 changes)

@@ -1007,12 +1007,13 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * b);
 
-    // concat a and b on dim 2
+    // concat a and b along dim
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_concat(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
-            struct ggml_tensor * b);
+            struct ggml_tensor * b,
+            int dim);
 
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
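
With the new signature, callers have to pass the concat dimension explicitly (the old behavior corresponds to dim = 2). A minimal usage sketch, not taken from the commit, assuming a plain CPU ggml context with arbitrary buffer size and tensor shapes:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // two F32 tensors that agree on every dimension except dim 1
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 5);

    // concatenate along dim 1 -> result shape [4, 8, 1, 1]
    struct ggml_tensor * c = ggml_concat(ctx, a, b, 1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("c: [%lld, %lld, %lld, %lld]\n",
           (long long) c->ne[0], (long long) c->ne[1],
           (long long) c->ne[2], (long long) c->ne[3]);

    ggml_free(ctx);
    return 0;
}
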


@@ -1274,7 +1274,7 @@ struct test_concat : public test_case {
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]);
-        ggml_tensor * out = ggml_concat(ctx, a, b);
+        ggml_tensor * out = ggml_concat(ctx, a, b, 2);
         return out;
     }
 };