add option to train with flash attention and move options to the top of the main function

training from scratch also works with flash attention.
training convergence and generation results after a fixed number of iterations are worse than without flash attention.
maybe a bug still lingers in the flash attention backward pass?
but training works, just with slower convergence.

flash attention is still worth using, because it requires far less memory and is faster at high n_ctx.
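For context, the heart of the change in forward_batch_wo_cache_flash_attn below is swapping the explicit attention pipeline for ggml's fused kernel. The following is a minimal sketch of the two paths; the helper build_attention and its parameter list are illustrative only, while the individual ggml calls are the ones used (or left commented out) in the diff:

// sketch only: assumes #include "ggml.h" and <math.h>
// Q, K shape [n_embd/n_head, N, n_head, n_batch]; V is given transposed, [N, n_embd/n_head, n_head, n_batch]
static struct ggml_tensor * build_attention(
        struct ggml_context * ctx0,
        struct ggml_tensor  * Q,
        struct ggml_tensor  * K,
        struct ggml_tensor  * V,
        int n_embd, int n_head, int n_past,
        bool use_flash) {
    if (use_flash) {
        // fused op: the 1/sqrt(d_head) scaling, causal masking and softmax happen inside the kernel,
        // so the [N, N, n_head, n_batch] score matrix is never materialized
        return ggml_flash_attn(ctx0, Q, K, V, /*masked=*/true);
    }
    // explicit path: compute attention scores, scale, mask, softmax, then weight V
    struct ggml_tensor * KQ          = ggml_mul_mat(ctx0, K, Q);
    struct ggml_tensor * KQ_scaled   = ggml_scale_inplace(ctx0, KQ,
            ggml_new_f32(ctx0, 1.0f/sqrtf((float)n_embd/n_head)));
    struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
    struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
    return ggml_mul_mat(ctx0, V, KQ_soft_max);
}

Not materializing the full score matrix is what gives the memory saving and the speedup at large n_ctx mentioned above.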
xaedes 2023-05-30 13:17:58 +02:00
parent 70c08318af
commit fcbc4457d6

@@ -1168,6 +1168,239 @@ struct ggml_tensor * forward_batch_wo_cache(
return inpL;
}
struct ggml_tensor * forward_batch_wo_cache_flash_attn(
struct my_llama_model * model,
struct ggml_context * ctx0,
struct ggml_cgraph * gf,
struct ggml_tensor * tokens_input,
const int n_tokens,
const int n_batch) {
const int n_past = 0;
const int N = n_tokens;
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_head = hparams.n_head;
const int n_rot = hparams.n_rot;
const int n_ff = get_n_ff(&hparams);
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
// inpL shape [n_embd,N*n_batch,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
assert_shape_2d(inpL, n_embd, N*n_batch);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
struct ggml_tensor * cur;
// lctx.use_buf(ctx0, 0);
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// self-attention
{
// compute Q and K and RoPE them
// wq shape [n_embd, n_embd, 1, 1]
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
// Vcur shape [N, n_batch, n_embd/n_head, n_head]
struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head);
assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head);
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Q shape [n_embd/n_head, N, n_head, n_batch]
struct ggml_tensor * Q =
ggml_permute(ctx0,
Qcur,
0, 2, 1, 3);
assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
// kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
// K shape [n_embd/n_head, N, n_head, n_batch]
struct ggml_tensor * K =
ggml_permute(ctx0,
Kcur,
0, 2, 1, 3);
assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch);
// // K * Q
// // KQ shape [N, N, n_head, n_batch]
// struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// assert_shape_4d(KQ, N, N, n_head, n_batch);
// // KQ_scaled = KQ / sqrt(n_embd/n_head)
// // KQ_scaled shape [N, N, n_head, n_batch]
// struct ggml_tensor * KQ_scaled =
// ggml_scale_inplace(ctx0,
// KQ,
// ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
// assert_shape_4d(KQ_scaled, N, N, n_head, n_batch);
// // KQ_masked = mask_past(KQ_scaled)
// // KQ_masked shape [N, N, n_head, n_batch]
// struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
// assert_shape_4d(KQ_masked, N, N, n_head, n_batch);
// // KQ = soft_max(KQ_masked)
// // KQ_soft_max shape [N, N, n_head, n_batch]
// struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
// assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch);
// Vcur shape [N, n_batch, n_embd/n_head, n_head]
// V shape [N, n_embd/n_head, n_head, n_batch]
struct ggml_tensor * V =
ggml_permute(ctx0,
Vcur,
0, 3, 1, 2);
assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch);
// // KQV shape [n_embd/n_head, N, n_head, n_batch]
// struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
// assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
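// note: the fused op below replaces the explicit KQ -> scale -> mask -> softmax -> KQV pipeline
// commented out above; with masked == true, causal masking and the 1/sqrt(d_head) scaling
// are applied inside ggml_flash_attn itself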
bool masked = true;
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, masked);
assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
// cur = KQV_merged.contiguous().view(n_embd, N*n_batch)
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
assert_shape_2d(cur, n_embd, N*n_batch);
// projection (no bias)
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_mul_mat(ctx0,
model->layers[il].wo,
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// lctx.use_buf(ctx0, 1);
// inpFF shape [n_embd,N*n_batch,1,1]
struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA);
assert_shape_2d(inpFF, n_embd, N*n_batch);
// feed-forward network
{
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// tmp shape [n_ff,N*n_batch,1,1]
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
model->layers[il].w3,
cur);
assert_shape_2d(tmp, n_ff, N*n_batch);
// cur shape [n_ff,N*n_batch,1,1]
cur = ggml_mul_mat(ctx0,
model->layers[il].w1,
cur);
assert_shape_2d(cur, n_ff, N*n_batch);
// SILU activation
// cur shape [n_ff,N*n_batch,1,1]
cur = ggml_silu(ctx0, cur);
assert_shape_2d(cur, n_ff, N*n_batch);
// cur shape [n_ff,N*n_batch,1,1]
cur = ggml_mul(ctx0, cur, tmp);
assert_shape_2d(cur, n_ff, N*n_batch);
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_mul_mat(ctx0,
model->layers[il].w2,
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_add_inplace(ctx0, cur, inpFF);
assert_shape_2d(cur, n_embd, N*n_batch);
// input for next layer
// inpL shape [n_embd,N*n_batch,1,1]
inpL = cur;
assert_shape_2d(inpL, n_embd, N*n_batch);
}
// norm
{
// inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL
// inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model->norm, inpL),
inpL);
assert_shape_2d(inpL, n_embd, N*n_batch);
//embeddings = inpL;
}
// lm_head
// inpL shape [n_vocab,N*n_batch,1,1]
inpL = ggml_mul_mat(ctx0, model->output, inpL);
assert_shape_2d(inpL, n_vocab, N*n_batch);
{
// inpL shape [n_vocab,N,n_batch,1]
inpL = ggml_reshape_3d(ctx0,
inpL,
n_vocab, N, n_batch);
assert_shape_3d(inpL, n_vocab, N, n_batch);
}
// run the computation
ggml_build_forward_expand(gf, inpL);
return inpL;
}
void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
*ptr = value;
@@ -1644,7 +1877,7 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
}
std::string name = file->read_string(name_len);
-GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)) == 0);
+GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0);
file->seek(-file->tell() & 31, SEEK_CUR);
file->read_raw(tensor->data, ggml_nbytes(tensor));
@@ -1930,7 +2163,42 @@ int main(int argc, char ** argv) {
//return 1;
}
-srand(time(NULL));
+int seed = 1;
int n_ctx = 256;
// int n_ctx = 64;
int n_embd = 256;
int n_mult = 256;
int n_head = 8;
int n_layer = 16;
int n_rotmax = 64;
int n_threads = 6;
int n_batch = 8;
int n_examples = 32;
int print_info_interval = 1;
int print_details_interval = 2;
bool samples_start_after_nl = false;
bool use_adam = true;
bool use_flash = false;
// only adam
int warmup = 100;
int cos_decay_steps = 1000;
float cos_decay_restart = 1.1f;
float cos_decay_alpha = 0.0f;
int lbfgs_n_iter = 16;
int adam_n_iter = 16;
float adam_alpha = 1e-3;
float adam_decay = 1e-3;
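// a negative seed falls back to a time-based seed; any value >= 0 gives reproducible runs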
if (seed < 0) {
srand(time(NULL));
} else {
srand(seed);
}
const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
@@ -1971,12 +2239,12 @@ int main(int argc, char ** argv) {
struct my_llama_model model;
model.hparams.n_vocab = llama_n_vocab(lctx);
-model.hparams.n_ctx = 32;
+model.hparams.n_ctx = n_ctx;
-model.hparams.n_embd = 128;
+model.hparams.n_embd = n_embd;
-model.hparams.n_mult = 64;
+model.hparams.n_mult = n_mult;
-model.hparams.n_head = 16;
+model.hparams.n_head = n_head;
-model.hparams.n_layer = 4;
+model.hparams.n_layer = n_layer;
-model.hparams.n_rot = std::min(64u, model.hparams.n_embd / model.hparams.n_head);
+model.hparams.n_rot = std::min((uint32_t)n_rotmax, model.hparams.n_embd / model.hparams.n_head);
print_params(&model.hparams);
@@ -2011,18 +2279,6 @@ int main(int argc, char ** argv) {
my_llama_sampler sampler;
-int n_threads = 6;
-int n_batch = 32;
-int n_examples = 32;
-bool samples_start_after_nl = false;
-bool use_adam = true;
-int warmup = 100;
-int cos_decay_steps = 1000;
-float cos_decay_restart = 1.1f;
-float cos_decay_alpha = 0.0f;
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
@@ -2035,15 +2291,15 @@ int main(int argc, char ** argv) {
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = n_threads;
-opt_params_adam.adam.n_iter = 16;
+opt_params_adam.adam.n_iter = adam_n_iter;
opt_params_adam.adam.sched = 1.0f;
-opt_params_adam.adam.alpha = 1e-3;
+opt_params_adam.adam.alpha = adam_alpha;
-opt_params_adam.adam.decay = 1e-3;
+opt_params_adam.adam.decay = adam_decay;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
opt_params_lbfgs.n_threads = n_threads;
-opt_params_lbfgs.lbfgs.n_iter = 16;
+opt_params_lbfgs.lbfgs.n_iter = lbfgs_n_iter;
opt->ctx = model.ctx;
opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
@@ -2117,7 +2373,9 @@ int main(int argc, char ** argv) {
struct ggml_tensor * logits =
(n_past == 0)
-? forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
+? (use_flash
? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
: forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
: forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
@@ -2148,16 +2406,16 @@ int main(int argc, char ** argv) {
float error_after_opt = ggml_get_f32_1d(e, 0);
-printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
-printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
-if (ex % 1 == 0) {
+if (ex % print_info_interval == 0) {
printf("Example %d, opt iter %d\n", ex, opt->iter);
printf("error_before_opt: %.6f\n", error_before_opt);
printf("error_after_opt: %.6f\n", error_after_opt);
printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
}
-if (ex % 2 == 0) {
+if (ex % print_details_interval == 0) {
// set_logits_masked(logits, token_notavail, -1e9);
for (int i=0; i<n_batch; ++i) {
init_sampler(&sampler, lctx);