add tensor checkpoints only when gradient checkpointing is enabled
This commit is contained in:
parent
e0da1684db
commit
4914f855c7
1 changed files with 17 additions and 11 deletions
|
@ -694,10 +694,12 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
struct ggml_tensor * cur = t01;
|
||||
|
||||
std::vector<struct ggml_tensor *> checkpoints;
|
||||
checkpoints.push_back(tokens_input);
|
||||
checkpoints.push_back(targets);
|
||||
checkpoints.push_back(t00);
|
||||
checkpoints.push_back(t01);
|
||||
if (enable_checkpointing) {
|
||||
checkpoints.push_back(tokens_input);
|
||||
checkpoints.push_back(targets);
|
||||
checkpoints.push_back(t00);
|
||||
checkpoints.push_back(t01);
|
||||
}
|
||||
|
||||
struct ggml_tensor * kv_scale = NULL;
|
||||
if (!enable_flash_attn) {
|
||||
|
@ -766,7 +768,9 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
struct ggml_tensor * t29 = ggml_mul_mat (ctx, w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
|
||||
cur = t30;
|
||||
checkpoints.push_back(cur);
|
||||
if (enable_checkpointing) {
|
||||
checkpoints.push_back(cur);
|
||||
}
|
||||
}
|
||||
struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t32 = ggml_repeat (ctx, norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch);
|
||||
|
@ -775,12 +779,14 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch);
|
||||
struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1);
|
||||
|
||||
checkpoints.push_back(t31);
|
||||
checkpoints.push_back(t32);
|
||||
checkpoints.push_back(t33);
|
||||
checkpoints.push_back(t34);
|
||||
checkpoints.push_back(t35);
|
||||
checkpoints.push_back(t36);
|
||||
if (enable_checkpointing) {
|
||||
checkpoints.push_back(t31);
|
||||
checkpoints.push_back(t32);
|
||||
checkpoints.push_back(t33);
|
||||
checkpoints.push_back(t34);
|
||||
checkpoints.push_back(t35);
|
||||
checkpoints.push_back(t36);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, t36);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue