add list of ops that support in-place
This commit is contained in:
parent
8fa548377a
commit
f67179aaf2
1 changed files with 28 additions and 2 deletions
30
ggml-alloc.c
30
ggml-alloc.c
|
@ -305,6 +305,33 @@ static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
||||||
return parent;
|
return parent;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ggml_op_can_inplace(enum ggml_op op) {
|
||||||
|
switch (op) {
|
||||||
|
case GGML_OP_SCALE:
|
||||||
|
case GGML_OP_DIAG_MASK_ZERO:
|
||||||
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_ADD1:
|
||||||
|
case GGML_OP_ACC:
|
||||||
|
case GGML_OP_SUB:
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
|
case GGML_OP_SQR:
|
||||||
|
case GGML_OP_SQRT:
|
||||||
|
case GGML_OP_LOG:
|
||||||
|
case GGML_OP_UNARY:
|
||||||
|
case GGML_OP_ROPE:
|
||||||
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_SET:
|
||||||
|
case GGML_OP_SOFT_MAX:
|
||||||
|
case GGML_OP_CONT:
|
||||||
|
return true;
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void allocate_node(struct ggml_allocator * alloc, struct ggml_tensor * node) {
|
static void allocate_node(struct ggml_allocator * alloc, struct ggml_tensor * node) {
|
||||||
if (node->data == NULL) {
|
if (node->data == NULL) {
|
||||||
if (ggml_is_view(node)) {
|
if (ggml_is_view(node)) {
|
||||||
|
@ -333,8 +360,7 @@ static void allocate_node(struct ggml_allocator * alloc, struct ggml_tensor * no
|
||||||
if (parent == NULL) {
|
if (parent == NULL) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// TODO: make a list of operations that can be safely made inplace
|
if (parent->data != NULL && parent->n_children == 1 && parent->n_views == 0 && ggml_are_same_layout(node, parent) && ggml_op_can_inplace(node->op)) {
|
||||||
if (parent->data != NULL && parent->n_children == 1 && parent->n_views == 0 && ggml_are_same_layout(node, parent) && node->op != GGML_OP_MUL_MAT) {
|
|
||||||
if (ggml_is_view(parent)) {
|
if (ggml_is_view(parent)) {
|
||||||
struct ggml_tensor * view_src = get_view_source(parent);
|
struct ggml_tensor * view_src = get_view_source(parent);
|
||||||
if (view_src->n_views == 1 && view_src->n_children == 0 && view_src->data == parent->data) {
|
if (view_src->n_views == 1 && view_src->n_children == 0 && view_src->data == parent->data) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue