correctly clone reshape and permute operations by also cloning tensor->nb values

This commit is contained in:
xaedes 2023-08-14 17:52:15 +02:00
parent d43741540b
commit cfddc36be2
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1440,6 +1440,9 @@ struct ggml_tensor * ggml_recompute_graph_node(
clone->grad = node->grad; clone->grad = node->grad;
clone->is_param = node->is_param; clone->is_param = node->is_param;
clone->extra = node->extra; clone->extra = node->extra;
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
clone->nb[k] = node->nb[k];
}
for (int k = 0; k < GGML_MAX_SRC; ++k) { for (int k = 0; k < GGML_MAX_SRC; ++k) {
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
} }