add option to use scratch buffers in training or not
Make it configurable, because currently training with scratch buffers implies flash attention and optimization over all parameters.
parent 0d4b87de3d
commit d9626743ac

1 changed file with 34 additions and 31 deletions
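With this change, the scratch-buffer training path becomes a run-time choice instead of the only path. A usage sketch (hedged: the binary name ./train is a placeholder, not taken from this commit; the flags are the ones added in the diff below, and "..." stands for the usual training arguments):

    # default: scratch buffers, which imply the fused flash-attention training graph
    ./train ... --use-scratch

    # keep flash attention but build the graphs without scratch buffers
    ./train ... --no-scratch --use-flash

    # plain forward/backward without flash attention or scratch buffers
    ./train ... --no-scratch --no-flash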
@@ -2614,6 +2614,7 @@ struct train_params {
     bool samples_start_after_nl;
     bool use_adam;
     bool use_flash;
+    bool use_scratch;
 
     // only adam
     int warmup;
@@ -2661,6 +2662,7 @@ struct train_params get_default_train_params() {
     params.samples_start_after_nl = false;
     params.use_adam = true;
     params.use_flash = true;
+    params.use_scratch = true;
 
     // only adam
     params.warmup = 100;
@@ -2710,6 +2712,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  --use-adam                 Use Adam optimizer (default)\n");
     fprintf(stderr, "  --no-flash                 Don't use flash attention.\n");
     fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --no-scratch               Don't use scratch buffers\n");
+    fprintf(stderr, "  --use-scratch              Use scratch buffers (default)\n");
     fprintf(stderr, "  --warmup N                 Number of warmup steps (default %d)\n", params->warmup);
     fprintf(stderr, "  --cos-decay-steps N        Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
     fprintf(stderr, "  --cos-decay-restart N      Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
@@ -2856,6 +2860,10 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
             params->use_flash = false;
         } else if (arg == "--use-flash") {
            params->use_flash = true;
+        } else if (arg == "--no-scratch") {
+            params->use_scratch = false;
+        } else if (arg == "--use-scratch") {
+            params->use_scratch = true;
         } else if (arg == "--warmup") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3146,38 +3154,36 @@ int main(int argc, char ** argv) {
         get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
-        // struct ggml_tensor * logits =
-        //     (n_past == 0)
-        //     ? (params.use_flash
-        //       ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
-        //       : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
-        //     : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
         GGML_ASSERT(n_past == 0);
 
-        // struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
-        struct ggml_tensor * logits;
-        struct ggml_tensor * e = forward_batch_wo_cache_flash_attn_train(
-            &model,
-            ctx0,
-            gf,
-            gb,
-            &logits,
-            tokens_input,
-            target_probs,
-            compute_buf_0,
-            compute_buf_1,
-            compute_buf_2,
-            size_buf_0,
-            size_buf_1,
-            size_buf_2,
-            n_tokens,
-            n_batch);
+        struct ggml_tensor * loss   = NULL;
+        struct ggml_tensor * logits = NULL;
+
+        if (params.use_scratch) {
+            loss = forward_batch_wo_cache_flash_attn_train(
+                &model, ctx0,
+                gf, gb,
+                &logits, tokens_input, target_probs,
+                compute_buf_0, compute_buf_1, compute_buf_2,
+                size_buf_0, size_buf_1, size_buf_2,
+                n_tokens, n_batch);
+        } else if (params.use_flash) {
+            logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
+            loss   = cross_entropy_loss(ctx0, logits, target_probs);
+            ggml_build_forward_expand(gf, loss);
+            *gb = ggml_build_backward(ctx0, gf, true);
+        } else {
+            logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
+            loss   = cross_entropy_loss(ctx0, logits, target_probs);
+            ggml_build_forward_expand(gf, loss);
+            *gb = ggml_build_backward(ctx0, gf, true);
+        }
 
         // ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, gf);
 
         size_t used_mem_before_opt = ggml_used_mem(ctx0);
 
-        float error_before_opt = ggml_get_f32_1d(e, 0);
+        float error_before_opt = ggml_get_f32_1d(loss, 0);
 
         opt->params.adam.sched = (opt->iter < params.warmup)
             ? (float) opt->iter / (float) params.warmup
@@ -3189,9 +3195,7 @@ int main(int argc, char ** argv) {
 
         printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
-        // ggml_opt(ctx0, opt->params, e);
-        // ggml_opt_resume(ctx0, opt, e);
-        ggml_opt_resume_g(ctx0, opt, e, gf, gb);
+        ggml_opt_resume_g(ctx0, opt, loss, gf, gb);
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
 
@@ -3199,10 +3203,9 @@ int main(int argc, char ** argv) {
         model.train_samples += n_batch;
         model.train_tokens += n_batch * n_tokens;
 
-        //ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, gf);
 
-        float error_after_opt = ggml_get_f32_1d(e, 0);
+        float error_after_opt = ggml_get_f32_1d(loss, 0);
 
         if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);
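To summarize the new control flow in main(): scratch buffers take precedence, and only the non-scratch path consults use_flash, which is exactly why the old behavior "implied" flash attention. The standalone sketch below is compiled from the diff's branch structure for illustration; params_sketch and chosen_path are invented names, while the forward_* functions are the ones named in the diff.

    #include <cstdio>

    // Invented stand-in for the relevant train_params fields.
    struct params_sketch {
        bool use_flash;
        bool use_scratch;
    };

    // Mirrors the if / else if / else added to main(): scratch buffers win,
    // and they select the fused flash-attention training graph.
    static const char * chosen_path(const params_sketch & p) {
        if (p.use_scratch) {
            return "forward_batch_wo_cache_flash_attn_train";
        }
        if (p.use_flash) {
            return "forward_batch_wo_cache_flash_attn + cross_entropy_loss";
        }
        return "forward_batch_wo_cache + cross_entropy_loss";
    }

    int main() {
        const params_sketch configs[] = {
            { true,  true  },  // defaults: --use-flash --use-scratch
            { true,  false },  // --no-scratch
            { false, false },  // --no-scratch --no-flash
        };
        for (const params_sketch & p : configs) {
            printf("use_flash=%d use_scratch=%d -> %s\n",
                   (int) p.use_flash, (int) p.use_scratch, chosen_path(p));
        }
        return 0;
    }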