diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 9b73361ca..410ba69b9 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1405,6 +1405,33 @@ void free_hash_map(struct hash_map * map) { delete map; } +static bool ggml_is_view(struct ggml_tensor * t) { + return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE || + t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY; +} + +static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) { + switch (t->op) { + case GGML_OP_PERMUTE: + case GGML_OP_RESHAPE: + case GGML_OP_TRANSPOSE: + case GGML_OP_VIEW: + return t->src[0]; + case GGML_OP_CPY: + return t->src[1]; + default: + return NULL; + } +} + +static struct ggml_tensor * get_view_source(struct ggml_tensor * t) { + struct ggml_tensor * parent = t; + do { + parent = get_view_parent(parent); + } while (ggml_is_view(parent)); + return parent; +} + struct ggml_tensor * ggml_recompute_graph_node( struct ggml_context * ctx, struct ggml_cgraph * graph, @@ -1457,6 +1484,11 @@ struct ggml_tensor * ggml_recompute_graph_node( for (int k = 0; k < GGML_MAX_SRC; ++k) { clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); } + if (ggml_is_view(clone)) { + struct ggml_tensor * source = get_view_source(clone); + GGML_ASSERT(source != NULL); + clone->data = source->data; + } GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);