diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index d2451bdca..65501c355 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -694,10 +694,12 @@ struct ggml_tensor * llama_build_lora_finetune_graphs( struct ggml_tensor * cur = t01; std::vector checkpoints; - checkpoints.push_back(tokens_input); - checkpoints.push_back(targets); - checkpoints.push_back(t00); - checkpoints.push_back(t01); + if (enable_checkpointing) { + checkpoints.push_back(tokens_input); + checkpoints.push_back(targets); + checkpoints.push_back(t00); + checkpoints.push_back(t01); + } struct ggml_tensor * kv_scale = NULL; if (!enable_flash_attn) { @@ -766,7 +768,9 @@ struct ggml_tensor * llama_build_lora_finetune_graphs( struct ggml_tensor * t29 = ggml_mul_mat (ctx, w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch); struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch); cur = t30; - checkpoints.push_back(cur); + if (enable_checkpointing) { + checkpoints.push_back(cur); + } } struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch); struct ggml_tensor * t32 = ggml_repeat (ctx, norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch); @@ -775,12 +779,14 @@ struct ggml_tensor * llama_build_lora_finetune_graphs( struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch); struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1); - checkpoints.push_back(t31); - checkpoints.push_back(t32); - checkpoints.push_back(t33); - checkpoints.push_back(t34); - checkpoints.push_back(t35); - checkpoints.push_back(t36); + if (enable_checkpointing) { + checkpoints.push_back(t31); + checkpoints.push_back(t32); + checkpoints.push_back(t33); + checkpoints.push_back(t34); + checkpoints.push_back(t35); + checkpoints.push_back(t36); + } ggml_build_forward_expand(gf, t36);