diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index ee17bd8e4..ff6167da8 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -2614,6 +2614,7 @@ struct train_params {
     bool samples_start_after_nl;
     bool use_adam;
     bool use_flash;
+    bool use_scratch;
 
     // only adam
     int warmup;
@@ -2661,6 +2662,7 @@ struct train_params get_default_train_params() {
     params.samples_start_after_nl = false;
     params.use_adam               = true;
     params.use_flash              = true;
+    params.use_scratch            = true;
 
     // only adam
     params.warmup = 100;
@@ -2710,6 +2712,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  --use-adam                 Use Adam optimizer (default)\n");
     fprintf(stderr, "  --no-flash                 Don't use flash attention.\n");
     fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --no-scratch               Don't use scratch buffers\n");
+    fprintf(stderr, "  --use-scratch              Use scratch buffers (default)\n");
     fprintf(stderr, "  --warmup N                 Number of warmup steps (default %d)\n", params->warmup);
     fprintf(stderr, "  --cos-decay-steps N        Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
     fprintf(stderr, "  --cos-decay-restart N      Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
@@ -2856,6 +2860,10 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
             params->use_flash = false;
         } else if (arg == "--use-flash") {
            params->use_flash = true;
+        } else if (arg == "--no-scratch") {
+            params->use_scratch = false;
+        } else if (arg == "--use-scratch") {
+            params->use_scratch = true;
         } else if (arg == "--warmup") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3146,38 +3154,36 @@ int main(int argc, char ** argv) {
 
         get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
-        // struct ggml_tensor * logits =
-        //     (n_past == 0)
-        //     ? (params.use_flash
-        //         ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
-        //         : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
-        //     : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
+        GGML_ASSERT(n_past == 0);
 
-        // struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
-        struct ggml_tensor * logits;
-        struct ggml_tensor * e = forward_batch_wo_cache_flash_attn_train(
-            &model,
-            ctx0,
-            gf,
-            gb,
-            &logits,
-            tokens_input,
-            target_probs,
-            compute_buf_0,
-            compute_buf_1,
-            compute_buf_2,
-            size_buf_0,
-            size_buf_1,
-            size_buf_2,
-            n_tokens,
-            n_batch);
+        struct ggml_tensor * loss = NULL;
+        struct ggml_tensor * logits = NULL;
+
+        if (params.use_scratch) {
+            loss = forward_batch_wo_cache_flash_attn_train(
+                &model, ctx0,
+                gf, gb,
+                &logits, tokens_input, target_probs,
+                compute_buf_0, compute_buf_1, compute_buf_2,
+                size_buf_0, size_buf_1, size_buf_2,
+                n_tokens, n_batch);
+        } else if (params.use_flash) {
+            logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
+            loss = cross_entropy_loss(ctx0, logits, target_probs);
+            ggml_build_forward_expand(gf, loss);
+            *gb = ggml_build_backward(ctx0, gf, true);
+        } else {
+            logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
+            loss = cross_entropy_loss(ctx0, logits, target_probs);
+            ggml_build_forward_expand(gf, loss);
+            *gb = ggml_build_backward(ctx0, gf, true);
+        }
 
-        // ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, gf);
 
         size_t used_mem_before_opt = ggml_used_mem(ctx0);
 
-        float error_before_opt = ggml_get_f32_1d(e, 0);
+        float error_before_opt = ggml_get_f32_1d(loss, 0);
 
         opt->params.adam.sched = (opt->iter < params.warmup)
             ? (float) opt->iter / (float) params.warmup
@@ -3189,9 +3195,7 @@
 
         printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
-        // ggml_opt(ctx0, opt->params, e);
-        // ggml_opt_resume(ctx0, opt, e);
-        ggml_opt_resume_g(ctx0, opt, e, gf, gb);
+        ggml_opt_resume_g(ctx0, opt, loss, gf, gb);
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
 
@@ -3199,10 +3203,9 @@
         model.train_samples += n_batch;
         model.train_tokens  += n_batch * n_tokens;
 
-        //ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, gf);
 
-        float error_after_opt = ggml_get_f32_1d(e, 0);
+        float error_after_opt = ggml_get_f32_1d(loss, 0);
 
         if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);