diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 51eb96fc9..03ec39d86 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1392,16 +1392,8 @@ struct hash_map { }; static const size_t HASH_MAP_SIZE = sizeof(struct hash_map); -struct hash_map * new_hash_map(struct ggml_context * ctx, struct ggml_tensor * * out_buf) { - struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, HASH_MAP_SIZE); - if (out_buf) { - * out_buf = buf; - } - struct hash_map * result = (struct hash_map *) ((char *) buf->data); - *result = (struct hash_map) { - /*.keys =*/ { NULL }, - /*.vals =*/ { NULL }, - }; +struct hash_map * new_hash_map() { + struct hash_map * result = new struct hash_map; for (int i=0; ikeys[i] = NULL; result->vals[i] = NULL; @@ -1409,6 +1401,10 @@ struct hash_map * new_hash_map(struct ggml_context * ctx, struct ggml_tensor * * return result; }; +void free_hash_map(struct hash_map * map) { + delete map; +} + struct ggml_tensor * ggml_recompute_graph_node( struct ggml_context * ctx, struct ggml_cgraph * graph, @@ -1471,7 +1467,7 @@ void ggml_build_backward_gradient_checkpointing( return; } - struct hash_map * replacements = new_hash_map(ctx, NULL); + struct hash_map * replacements = new_hash_map(); // insert checkpoints in replacements for (int i = 0; i < n_checkpoints; ++i) { @@ -1498,6 +1494,8 @@ void ggml_build_backward_gradient_checkpointing( // insert rewritten backward node with replacements made into resulting backward graph gb ggml_build_forward_expand(gb, node); } + + free_hash_map(replacements); } struct ggml_tensor * llama_build_train_graphs(