add option to use scratch buffers in training or not

Make it configurable, because currently training with scratch buffers implies flash attention and optimization over all parameters.
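For context, a minimal sketch of what "scratch buffers" means here (illustrative only; the helper name, buffer size, and placement are assumptions, not code from this commit): ggml can be told to place the data of subsequently created tensors into a caller-provided scratch buffer, so short-lived intermediates of the training graph do not consume the main context pool.

    #include <cstdint>
    #include <vector>
    #include "ggml.h"

    // Illustrative helper (hypothetical): route intermediate tensors into a
    // caller-provided scratch buffer, then switch back to normal allocation.
    static void build_intermediates(struct ggml_context * ctx0, std::vector<uint8_t> & buf) {
        struct ggml_scratch scratch = { /*offs =*/ 0, /*size =*/ buf.size(), /*data =*/ buf.data() };
        ggml_set_scratch(ctx0, scratch);            // tensor data now lands in buf
        // ... create the temporary tensors of the forward/backward pass here ...
        ggml_set_scratch(ctx0, { 0, 0, nullptr });  // persistent tensors go back to ctx0's own pool
    }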
xaedes 2023-06-01 20:59:19 +02:00
parent 0d4b87de3d
commit d9626743ac


@@ -2614,6 +2614,7 @@ struct train_params {
     bool samples_start_after_nl;
     bool use_adam;
     bool use_flash;
+    bool use_scratch;
 
     // only adam
     int warmup;
@@ -2661,6 +2662,7 @@ struct train_params get_default_train_params() {
     params.samples_start_after_nl = false;
     params.use_adam = true;
     params.use_flash = true;
+    params.use_scratch = true;
 
     // only adam
     params.warmup = 100;
@@ -2710,6 +2712,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
    fprintf(stderr, "  --use-adam                 Use Adam optimizer (default)\n");
    fprintf(stderr, "  --no-flash                 Don't use flash attention.\n");
    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --no-scratch               Don't use scratch buffers\n");
+    fprintf(stderr, "  --use-scratch              Use scratch buffers (default)\n");
    fprintf(stderr, "  --warmup N                 Number of warmup steps (default %d)\n", params->warmup);
    fprintf(stderr, "  --cos-decay-steps N        Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
    fprintf(stderr, "  --cos-decay-restart N      Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
@@ -2856,6 +2860,10 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
             params->use_flash = false;
         } else if (arg == "--use-flash") {
             params->use_flash = true;
+        } else if (arg == "--no-scratch") {
+            params->use_scratch = false;
+        } else if (arg == "--use-scratch") {
+            params->use_scratch = true;
         } else if (arg == "--warmup") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3146,38 +3154,36 @@ int main(int argc, char ** argv) {
         get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
-        // struct ggml_tensor * logits =
-        //     (n_past == 0)
-        //     ? (params.use_flash
-        //         ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
-        //         : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
-        //     : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
-        // struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
+        GGML_ASSERT(n_past == 0);
 
-        struct ggml_tensor * logits;
-        struct ggml_tensor * e = forward_batch_wo_cache_flash_attn_train(
-            &model,
-            ctx0,
-            gf,
-            gb,
-            &logits,
-            tokens_input,
-            target_probs,
-            compute_buf_0,
-            compute_buf_1,
-            compute_buf_2,
-            size_buf_0,
-            size_buf_1,
-            size_buf_2,
-            n_tokens,
-            n_batch);
+        struct ggml_tensor * loss = NULL;
+        struct ggml_tensor * logits = NULL;
 
-        // ggml_build_forward_expand(&gf, e);
+        if (params.use_scratch) {
+            loss = forward_batch_wo_cache_flash_attn_train(
+                &model, ctx0,
+                gf, gb,
+                &logits, tokens_input, target_probs,
+                compute_buf_0, compute_buf_1, compute_buf_2,
+                size_buf_0, size_buf_1, size_buf_2,
+                n_tokens, n_batch);
+        } else if (params.use_flash) {
+            logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
+            loss = cross_entropy_loss(ctx0, logits, target_probs);
+            ggml_build_forward_expand(gf, loss);
+            *gb = ggml_build_backward(ctx0, gf, true);
+        } else {
+            logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
+            loss = cross_entropy_loss(ctx0, logits, target_probs);
+            ggml_build_forward_expand(gf, loss);
+            *gb = ggml_build_backward(ctx0, gf, true);
+        }
 
         ggml_graph_compute(ctx0, gf);
 
         size_t used_mem_before_opt = ggml_used_mem(ctx0);
 
-        float error_before_opt = ggml_get_f32_1d(e, 0);
+        float error_before_opt = ggml_get_f32_1d(loss, 0);
 
         opt->params.adam.sched = (opt->iter < params.warmup)
             ? (float) opt->iter / (float) params.warmup
@@ -3189,9 +3195,7 @@ int main(int argc, char ** argv) {
         printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
-        // ggml_opt(ctx0, opt->params, e);
-        // ggml_opt_resume(ctx0, opt, e);
-        ggml_opt_resume_g(ctx0, opt, e, gf, gb);
+        ggml_opt_resume_g(ctx0, opt, loss, gf, gb);
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
@@ -3199,10 +3203,9 @@ int main(int argc, char ** argv) {
         model.train_samples += n_batch;
         model.train_tokens += n_batch * n_tokens;
 
-        //ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, gf);
 
-        float error_after_opt = ggml_get_f32_1d(e, 0);
+        float error_after_opt = ggml_get_f32_1d(loss, 0);
 
         if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);