From c993246bfd1c21240ebf7b477a9774c10a451371 Mon Sep 17 00:00:00 2001 From: xaedes Date: Sun, 17 Sep 2023 17:52:22 +0200 Subject: [PATCH] train-text-from-scratch: automatically allocate compute memory --- .../train-text-from-scratch.cpp | 177 ++++++++++-------- 1 file changed, 94 insertions(+), 83 deletions(-) diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 80cf2e255..ecb71c0ae 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1081,13 +1081,6 @@ int main(int argc, char ** argv) { printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f)); printf("%s: opt iter %d\n", __func__, opt->iter); - // TODO: use std::vector intead of "new" - size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb); - uint8_t * compute_addr = new uint8_t[compute_size]; - - size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb); - uint8_t * compute_buf_0 = new uint8_t[size_buf_0]; - int n_tokens = model.hparams.n_ctx; int n_vocab = model.hparams.n_vocab; int n_batch = params.common.n_batch; @@ -1124,9 +1117,82 @@ int main(int argc, char ** argv) { ggml_allocr_alloc(alloc, target_probs); ggml_allocr_free(alloc); - if (params.use_alloc) { - alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment); + // context for compute tensors without their data + size_t estimated_compute_size_wo_data = ( + ggml_tensor_overhead()*GGML_MAX_NODES*2 + + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*( + params.common.use_checkpointing ? 3 : 2 + ) + ); + struct ggml_init_params ctx_compute_params = { + estimated_compute_size_wo_data, // mem_size + NULL, // mem_buffer + true, // no_alloc + }; + struct ggml_context * ctx_compute = NULL; + + struct ggml_tensor * loss = NULL; + struct ggml_tensor * logits = NULL; + + struct ggml_cgraph * gf = NULL; + struct ggml_cgraph * gb = NULL; + struct ggml_cgraph * gb_tmp = NULL; + + // measure required memory for compute tensors + size_t best_compute_size = SIZE_MAX; + enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT; + // find best evaluation order + for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) { + ctx_compute = ggml_init(ctx_compute_params); + alloc = ggml_allocr_new_measure(tensor_alignment); + gf = ggml_new_graph(ctx_compute); + gf->order = (enum ggml_cgraph_eval_order) order; + gb = ggml_new_graph(ctx_compute); + gb_tmp = params.common.use_checkpointing + ? ggml_new_graph(ctx_compute) + : NULL; + loss = llama_build_train_graphs( + &model, alloc, ctx_compute, + gf, gb, gb_tmp, + &logits, tokens_input, target_probs, + n_tokens, n_batch, + params.common.use_flash, + params.common.use_checkpointing + ); + size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment; + if (max_compute_size < best_compute_size) { + best_compute_size = max_compute_size; + best_order = gf->order; + } + ggml_allocr_free(alloc); + ggml_free(ctx_compute); } + size_t max_compute_size = best_compute_size; + printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f)); + printf("%s: evaluation order = %s\n", __func__, + (best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" : + (best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" : + "invalid"); + + // allocate compute tensors + mem_compute_data.resize(max_compute_size); + ctx_compute = ggml_init(ctx_compute_params); + alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment); + gf = ggml_new_graph(ctx_compute); + gf->order = best_order; + gb = ggml_new_graph(ctx_compute); + gb_tmp = params.common.use_checkpointing + ? ggml_new_graph(ctx_compute) + : NULL; + loss = llama_build_train_graphs( + &model, alloc, ctx_compute, + gf, gb, gb_tmp, + &logits, tokens_input, target_probs, + n_tokens, n_batch, + params.common.use_flash, + params.common.use_checkpointing + ); + ggml_allocr_free(alloc); std::vector train_tokens; std::vector train_samples_begin; @@ -1204,83 +1270,30 @@ int main(int argc, char ** argv) { opt_cb_data.last_time = ggml_time_ms(); opt_cb_data.millis_per_iter = 0.0; + // measure required memory for work buffer + size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE; + printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f)); + + // context for work buffer + struct ggml_init_params ctx_work_params = { + max_work_size, // mem_size + NULL, // mem_buffer + false, // no_alloc + }; + struct ggml_context * ctx_work = ggml_init(ctx_work_params); + int64_t t0 = ggml_time_ms(); - for (int ex = 0; ex < params.n_examples; ++ex) { + ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data); - struct ggml_init_params cparams = { - compute_size, // mem_size - compute_addr, // mem_buffer - false, // no_alloc - }; - struct ggml_context * ctx0 = ggml_init(cparams); - - ggml_set_no_alloc(ctx0, (alloc != NULL)); - - if (alloc) { - ggml_allocr_reset(alloc); - } - - int n_past = 0; - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - struct ggml_cgraph * gb = ggml_new_graph(ctx0); - struct ggml_cgraph * gb_tmp = params.common.use_checkpointing - ? ggml_new_graph(ctx0) - : NULL; - - GGML_ASSERT(n_past == 0); - - struct ggml_tensor * loss = NULL; - struct ggml_tensor * logits = NULL; - - loss = llama_build_train_graphs( - &model, alloc, ctx0, - gf, gb, gb_tmp, - &logits, tokens_input, target_probs, - n_tokens, n_batch, - params.common.use_flash, - params.common.use_checkpointing - ); - - size_t used_mem_before_opt = ggml_used_mem(ctx0); - - opt->params.adam.sched = learning_schedule( - opt->iter, - params.common.warmup, - params.common.cos_decay_steps, - params.common.adam_alpha, - params.common.adam_min_alpha, - params.common.cos_decay_min, - params.common.cos_decay_restart, - params.common.enable_restart); - - printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched); - - ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data); - - size_t used_mem_after_opt = ggml_used_mem(ctx0); - - int n_iter = params.common.adam_n_iter; - train->train_its = opt->iter; - train->train_samples += n_batch * n_iter; - train->train_tokens += n_batch * n_tokens * n_iter; - - if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) { - printf("Example %d, opt iter %d\n", ex, opt->iter); - printf("error_before_opt: %.6f\n", opt->loss_before); - printf("error_after_opt: %.6f\n", opt->loss_after); - printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt); - printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt); - } - - ggml_free(ctx0); - } + ggml_free(ctx_work); + ggml_free(ctx_compute); + ggml_free(ctx_input); int64_t t1 = ggml_time_ms(); - int64_t d = t1-t0; - double dd = (double) d * 1e-3; - printf("%s: total training time=%f seconds\n", __func__, dd); + printf("%s: total training time: ", __func__); + print_duration((double) (t1 - t0)); + printf("\n"); int new_iters = opt->iter - opt_cb_data.last_save_iter; if (new_iters > 0) { @@ -1295,8 +1308,6 @@ int main(int argc, char ** argv) { ggml_allocr_free(alloc); } - delete[] compute_addr; - delete[] compute_buf_0; ggml_free(opt->ctx); free_train_state(train); ggml_free(model.ctx);