add ggml_reshape_1d, ggml_reshape_4d and ggml_view_4d
This commit is contained in:
parent
f1d51d144b
commit
b4c273f7a3
2 changed files with 117 additions and 0 deletions
92
ggml.c
92
ggml.c
|
@ -5924,6 +5924,30 @@ struct ggml_tensor * ggml_reshape(
|
|||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_reshape_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(ggml_nelements(a) == ne0);
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
const int64_t ne[1] = { ne0 };
|
||||
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a->data);
|
||||
|
||||
result->op = GGML_OP_RESHAPE;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = NULL;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_reshape_2d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
@ -5975,6 +5999,34 @@ struct ggml_tensor * ggml_reshape_3d(
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
struct ggml_tensor * ggml_reshape_4d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3) {
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
|
||||
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data);
|
||||
|
||||
result->op = GGML_OP_RESHAPE;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = NULL;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_view_1d
|
||||
|
||||
struct ggml_tensor * ggml_view_1d(
|
||||
|
@ -6077,6 +6129,46 @@ struct ggml_tensor * ggml_view_3d(
|
|||
return result;
|
||||
}
|
||||
|
||||
// ggml_view_4d
|
||||
|
||||
struct ggml_tensor * ggml_view_4d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3,
|
||||
size_t nb1,
|
||||
size_t nb2,
|
||||
size_t nb3,
|
||||
size_t offset) {
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 };
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
|
||||
|
||||
result->nb[1] = nb1;
|
||||
result->nb[2] = nb2;
|
||||
result->nb[3] = nb3;
|
||||
|
||||
result->op = GGML_OP_VIEW;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = NULL;
|
||||
|
||||
if (is_node) {
|
||||
memcpy(result->padding, &offset, sizeof(offset));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_permute
|
||||
|
||||
struct ggml_tensor * ggml_permute(
|
||||
|
|
25
ggml.h
25
ggml.h
|
@ -649,6 +649,11 @@ extern "C" {
|
|||
|
||||
// return view(a)
|
||||
// TODO: when we start computing gradient, make a copy instead of view
|
||||
GGML_API struct ggml_tensor * ggml_reshape_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_reshape_2d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
@ -664,6 +669,14 @@ extern "C" {
|
|||
int64_t ne1,
|
||||
int64_t ne2);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_reshape_4d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3);
|
||||
|
||||
// offset in bytes
|
||||
GGML_API struct ggml_tensor * ggml_view_1d(
|
||||
struct ggml_context * ctx,
|
||||
|
@ -689,6 +702,18 @@ extern "C" {
|
|||
size_t nb2, // slice stride in bytes
|
||||
size_t offset);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_view_4d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3,
|
||||
size_t nb1, // row stride in bytes
|
||||
size_t nb2, // slice stride in bytes
|
||||
size_t nb3,
|
||||
size_t offset);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_permute(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue