correctly clone view tensors by setting data pointers

Without this, checkpointing would only work when used together with the memory allocator.
This commit is contained in:
xaedes 2023-08-14 17:55:13 +02:00
parent 52c92c0a8c
commit 345f516f7c
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1405,6 +1405,33 @@ void free_hash_map(struct hash_map * map) {
delete map;
}
static bool ggml_is_view(struct ggml_tensor * t) {
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
}
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
switch (t->op) {
case GGML_OP_PERMUTE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
case GGML_OP_VIEW:
return t->src[0];
case GGML_OP_CPY:
return t->src[1];
default:
return NULL;
}
}
// Follows the chain of view parents starting at t and returns the first
// tensor in that chain that is not itself a view op (per ggml_is_view).
//
// Returns NULL when no such tensor exists: t is NULL, t is not a view op
// (get_view_parent yields NULL), or a parent slot along the chain is NULL.
// The NULL check must come before ggml_is_view, which dereferences its
// argument unconditionally.
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
    if (t == NULL) {
        return NULL;
    }
    struct ggml_tensor * parent = get_view_parent(t);
    while (parent != NULL && ggml_is_view(parent)) {
        parent = get_view_parent(parent);
    }
    return parent;
}
struct ggml_tensor * ggml_recompute_graph_node(
struct ggml_context * ctx,
struct ggml_cgraph * graph,
@ -1457,6 +1484,11 @@ struct ggml_tensor * ggml_recompute_graph_node(
for (int k = 0; k < GGML_MAX_SRC; ++k) {
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
}
if (ggml_is_view(clone)) {
struct ggml_tensor * source = get_view_source(clone);
GGML_ASSERT(source != NULL);
clone->data = source->data;
}
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);