diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 48edf3651..88a1c3a50 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1585,8 +1585,8 @@ struct ggml_tensor * llama_build_train_graphs( checkpoints.push_back(t01); struct ggml_tensor * kv_scale; - if (flash_attn) { - kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head))); + if (!enable_flash_attn) { + kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head)); } for (int il = 0; il < n_layer; ++il) {