From b4c273f7a320104db0b81b9f45d936fa9266b951 Mon Sep 17 00:00:00 2001 From: xaedes Date: Sat, 6 May 2023 17:29:41 +0200 Subject: [PATCH] add ggml_reshape_1d, ggml_reshape_4d and ggml_view_4d --- ggml.c | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ggml.h | 25 ++++++++++++++++ 2 files changed, 117 insertions(+) diff --git a/ggml.c b/ggml.c index ce7fda0e7..cd4f54bf1 100644 --- a/ggml.c +++ b/ggml.c @@ -5924,6 +5924,30 @@ struct ggml_tensor * ggml_reshape( return result; } +struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[1] = { ne0 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + struct ggml_tensor * ggml_reshape_2d( struct ggml_context * ctx, struct ggml_tensor * a, @@ -5975,6 +5999,34 @@ struct ggml_tensor * ggml_reshape_3d( return result; } + +struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + // ggml_view_1d struct ggml_tensor * ggml_view_1d( @@ -6077,6 +6129,46 @@ struct ggml_tensor * ggml_view_3d( return result; } +// ggml_view_4d + +struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 }; + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = nb3; + + result->op = GGML_OP_VIEW; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + if (is_node) { + memcpy(result->padding, &offset, sizeof(offset)); + } + + return result; +} + // ggml_permute struct ggml_tensor * ggml_permute( diff --git a/ggml.h b/ggml.h index 832281566..ae1c82e15 100644 --- a/ggml.h +++ b/ggml.h @@ -649,6 +649,11 @@ extern "C" { // return view(a) // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + GGML_API struct ggml_tensor * ggml_reshape_2d( struct ggml_context * ctx, struct ggml_tensor * a, @@ -664,6 +669,14 @@ extern "C" { int64_t ne1, int64_t ne2); + GGML_API struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + // offset in bytes GGML_API struct ggml_tensor * ggml_view_1d( struct ggml_context * ctx, @@ -689,6 +702,18 @@ extern "C" { size_t nb2, // slice stride in bytes size_t offset); + GGML_API struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t nb3, + size_t offset); + GGML_API struct ggml_tensor * ggml_permute( struct ggml_context * ctx, struct ggml_tensor * a,