ggml : add view_src and view_offs to ggml_tensor for views (#2874)

* ggml : add view_src and view_offs

* update ggml-alloc to use view_src

* update ggml_diag_mask to work correctly with automatic inplace

* exclude other ops that set an inplace flag from automatic inplace
This commit is contained in:
slaren 2023-08-29 23:24:42 +02:00 committed by GitHub
parent c03a243abf
commit 06abf8eeba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 178 deletions

View file

@ -321,8 +321,7 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
//////////// compute graph allocator
static bool ggml_is_view(struct ggml_tensor * t) {
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
return t->view_src != NULL;
}
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
@ -340,28 +339,6 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
return true;
}
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
switch (t->op) {
case GGML_OP_PERMUTE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
case GGML_OP_VIEW:
return t->src[0];
case GGML_OP_CPY:
return t->src[1];
default:
return NULL;
}
}
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
struct ggml_tensor * parent = t;
do {
parent = get_view_parent(parent);
} while (ggml_is_view(parent));
return parent;
}
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
case GGML_OP_SCALE:
@ -369,7 +346,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
@ -379,7 +355,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_UNARY:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
case GGML_OP_SET:
case GGML_OP_SOFT_MAX:
case GGML_OP_CONT:
return true;
@ -393,24 +368,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
struct hash_node * ht = alloc->hash_table;
if (node->data == NULL) {
if (ggml_is_view(node)) {
size_t offset;
switch(node->op) {
case GGML_OP_VIEW:
memcpy(&offset, node->op_params, sizeof(size_t));
node->data = (char *) node->src[0]->data + offset;
break;
case GGML_OP_PERMUTE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
node->data = node->src[0]->data;
break;
case GGML_OP_CPY:
node->data = node->src[1]->data;
break;
default:
GGML_ASSERT(!"unknown view op");
break;
}
assert(node->view_src->data != NULL);
node->data = (char *)node->view_src->data + node->view_offs;
} else {
// see if we can reuse a parent's buffer (inplace)
if (ggml_op_can_inplace(node->op)) {
@ -430,7 +389,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
struct hash_node * p_hn = hash_get(ht, parent);
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = get_view_source(parent);
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = hash_get(ht, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
@ -472,7 +431,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
struct ggml_tensor * node = gf->nodes[i];
if (ggml_is_view(node)) {
struct ggml_tensor * view_src = get_view_source(node);
struct ggml_tensor * view_src = node->view_src;
hash_get(ht, view_src)->n_views += 1;
}
@ -557,7 +516,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = get_view_source(parent);
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);