correctly clone view tensors by setting data pointers
Without this, checkpointing would only work when used together with the memory allocator.
This commit is contained in:
parent
52c92c0a8c
commit
345f516f7c
1 changed files with 32 additions and 0 deletions
|
@ -1405,6 +1405,33 @@ void free_hash_map(struct hash_map * map) {
|
|||
delete map;
|
||||
}
|
||||
|
||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
||||
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
|
||||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_VIEW:
|
||||
return t->src[0];
|
||||
case GGML_OP_CPY:
|
||||
return t->src[1];
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
||||
struct ggml_tensor * parent = t;
|
||||
do {
|
||||
parent = get_view_parent(parent);
|
||||
} while (ggml_is_view(parent));
|
||||
return parent;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_recompute_graph_node(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * graph,
|
||||
|
@ -1457,6 +1484,11 @@ struct ggml_tensor * ggml_recompute_graph_node(
|
|||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
||||
}
|
||||
if (ggml_is_view(clone)) {
|
||||
struct ggml_tensor * source = get_view_source(clone);
|
||||
GGML_ASSERT(source != NULL);
|
||||
clone->data = source->data;
|
||||
}
|
||||
|
||||
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
||||
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue