From 4914f855c7dfcb23fc9fe98c9cd3329f7215c630 Mon Sep 17 00:00:00 2001 From: xaedes Date: Thu, 31 Aug 2023 16:46:21 +0200 Subject: [PATCH] add tensor checkpoints only when gradient checkpointing is enabled --- examples/finetune/finetune.cpp | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index d2451bdca..65501c355 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -694,10 +694,12 @@ struct ggml_tensor * llama_build_lora_finetune_graphs( struct ggml_tensor * cur = t01; std::vector checkpoints; - checkpoints.push_back(tokens_input); - checkpoints.push_back(targets); - checkpoints.push_back(t00); - checkpoints.push_back(t01); + if (enable_checkpointing) { + checkpoints.push_back(tokens_input); + checkpoints.push_back(targets); + checkpoints.push_back(t00); + checkpoints.push_back(t01); + } struct ggml_tensor * kv_scale = NULL; if (!enable_flash_attn) { @@ -766,7 +768,9 @@ struct ggml_tensor * llama_build_lora_finetune_graphs( struct ggml_tensor * t29 = ggml_mul_mat (ctx, w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch); struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch); cur = t30; - checkpoints.push_back(cur); + if (enable_checkpointing) { + checkpoints.push_back(cur); + } } struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch); struct ggml_tensor * t32 = ggml_repeat (ctx, norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch); @@ -775,12 +779,14 @@ struct ggml_tensor * llama_build_lora_finetune_graphs( struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch); struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1); - checkpoints.push_back(t31); - checkpoints.push_back(t32); - checkpoints.push_back(t33); - checkpoints.push_back(t34); - checkpoints.push_back(t35); - checkpoints.push_back(t36); + if (enable_checkpointing) { + checkpoints.push_back(t31); + checkpoints.push_back(t32); + checkpoints.push_back(t33); + checkpoints.push_back(t34); + checkpoints.push_back(t35); + checkpoints.push_back(t36); + } ggml_build_forward_expand(gf, t36);