add ggml_reshape_1d, ggml_reshape_4d and ggml_view_4d

This commit is contained in:
xaedes 2023-05-06 17:29:41 +02:00
parent f1d51d144b
commit b4c273f7a3
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 117 additions and 0 deletions

92
ggml.c
View file

@ -5924,6 +5924,30 @@ struct ggml_tensor * ggml_reshape(
return result; return result;
} }
struct ggml_tensor * ggml_reshape_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0) {
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(ggml_nelements(a) == ne0);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[1] = { ne0 };
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a->data);
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL;
return result;
}
struct ggml_tensor * ggml_reshape_2d( struct ggml_tensor * ggml_reshape_2d(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -5975,6 +5999,34 @@ struct ggml_tensor * ggml_reshape_3d(
return result; return result;
} }
struct ggml_tensor * ggml_reshape_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3) {
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data);
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL;
return result;
}
// ggml_view_1d // ggml_view_1d
struct ggml_tensor * ggml_view_1d( struct ggml_tensor * ggml_view_1d(
@ -6077,6 +6129,46 @@ struct ggml_tensor * ggml_view_3d(
return result; return result;
} }
// ggml_view_4d
struct ggml_tensor * ggml_view_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 };
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
result->nb[1] = nb1;
result->nb[2] = nb2;
result->nb[3] = nb3;
result->op = GGML_OP_VIEW;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL;
if (is_node) {
memcpy(result->padding, &offset, sizeof(offset));
}
return result;
}
// ggml_permute // ggml_permute
struct ggml_tensor * ggml_permute( struct ggml_tensor * ggml_permute(

25
ggml.h
View file

@ -649,6 +649,11 @@ extern "C" {
// return view(a) // return view(a)
// TODO: when we start computing gradient, make a copy instead of view // TODO: when we start computing gradient, make a copy instead of view
GGML_API struct ggml_tensor * ggml_reshape_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0);
GGML_API struct ggml_tensor * ggml_reshape_2d( GGML_API struct ggml_tensor * ggml_reshape_2d(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -664,6 +669,14 @@ extern "C" {
int64_t ne1, int64_t ne1,
int64_t ne2); int64_t ne2);
GGML_API struct ggml_tensor * ggml_reshape_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3);
// offset in bytes // offset in bytes
GGML_API struct ggml_tensor * ggml_view_1d( GGML_API struct ggml_tensor * ggml_view_1d(
struct ggml_context * ctx, struct ggml_context * ctx,
@ -689,6 +702,18 @@ extern "C" {
size_t nb2, // slice stride in bytes size_t nb2, // slice stride in bytes
size_t offset); size_t offset);
GGML_API struct ggml_tensor * ggml_view_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
size_t nb1, // row stride in bytes
size_t nb2, // slice stride in bytes
size_t nb3,
size_t offset);
GGML_API struct ggml_tensor * ggml_permute( GGML_API struct ggml_tensor * ggml_permute(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,