add list of ops that support in-place

This commit is contained in:
slaren 2023-07-27 16:11:32 +02:00
parent 8fa548377a
commit f67179aaf2

View file

@ -305,6 +305,33 @@ static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
return parent;
}
bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
case GGML_OP_SCALE:
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_UNARY:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
case GGML_OP_SET:
case GGML_OP_SOFT_MAX:
case GGML_OP_CONT:
return true;
default:
return false;
}
}
static void allocate_node(struct ggml_allocator * alloc, struct ggml_tensor * node) {
if (node->data == NULL) {
if (ggml_is_view(node)) {
@ -333,8 +360,7 @@ static void allocate_node(struct ggml_allocator * alloc, struct ggml_tensor * no
if (parent == NULL) {
break;
}
// TODO: make a list of operations that can be safely made inplace
if (parent->data != NULL && parent->n_children == 1 && parent->n_views == 0 && ggml_are_same_layout(node, parent) && node->op != GGML_OP_MUL_MAT) {
if (parent->data != NULL && parent->n_children == 1 && parent->n_views == 0 && ggml_are_same_layout(node, parent) && ggml_op_can_inplace(node->op)) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = get_view_source(parent);
if (view_src->n_views == 1 && view_src->n_children == 0 && view_src->data == parent->data) {