diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp
index 267f44321..418cc5fff 100644
--- a/examples/baby-llama/baby-llama-text.cpp
+++ b/examples/baby-llama/baby-llama-text.cpp
@@ -1168,6 +1168,239 @@ struct ggml_tensor * forward_batch_wo_cache(
     return inpL;
 }
 
+struct ggml_tensor * forward_batch_wo_cache_flash_attn(
+        struct my_llama_model * model,
+        struct ggml_context * ctx0,
+        struct ggml_cgraph * gf,
+        struct ggml_tensor * tokens_input,
+        const int n_tokens,
+        const int n_batch) {
+
+    const int n_past = 0;
+    const int N = n_tokens;
+
+    const auto & hparams = model->hparams;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_vocab = hparams.n_vocab;
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_head  = hparams.n_head;
+    const int n_rot   = hparams.n_rot;
+    const int n_ff    = get_n_ff(&hparams);
+
+    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
+    memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
+
+    // inpL shape [n_embd,N*n_batch,1]
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
+    assert_shape_2d(inpL, n_embd, N*n_batch);
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        // lctx.use_buf(ctx0, 0);
+
+        // norm
+        {
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_rms_norm(ctx0, inpL);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
+                    cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            // wq shape [n_embd, n_embd, 1, 1]
+            // wk shape [n_embd, n_embd, 1, 1]
+            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
+            // Kcur shape [n_embd/n_head, n_head, N, n_batch]
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
+            assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
+
+            // Vcur shape [N, n_batch, n_embd/n_head, n_head]
+            struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head);
+            assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head);
+
+            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
+            // Q shape    [n_embd/n_head, N, n_head, n_batch]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        Qcur,
+                        0, 2, 1, 3);
+            assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
+
+            // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
+            // K shape         [n_embd/n_head, N, n_head, n_batch]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        Kcur,
+                        0, 2, 1, 3);
+            assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch);
+
+            // // K * Q
+            // // KQ shape [N, N, n_head, n_batch]
+            // struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            // assert_shape_4d(KQ, N, N, n_head, n_batch);
+
+            // // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // // KQ_scaled shape [N, N, n_head, n_batch]
+            // struct ggml_tensor * KQ_scaled =
+            //     ggml_scale_inplace(ctx0,
+            //             KQ,
+            //             ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+            // assert_shape_4d(KQ_scaled, N, N, n_head, n_batch);
+
+            // // KQ_masked = mask_past(KQ_scaled)
+            // // KQ_masked shape [N, N, n_head, n_batch]
+            // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            // assert_shape_4d(KQ_masked, N, N, n_head, n_batch);
+
+            // // KQ = soft_max(KQ_masked)
+            // // KQ_soft_max shape [N, N, n_head, n_batch]
+            // struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            // assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch);
+
+            // Vcur shape [N, n_batch, n_embd/n_head, n_head]
+            // V shape    [N, n_embd/n_head, n_head, n_batch]
+            struct ggml_tensor * V =
+                ggml_permute(ctx0,
+                        Vcur,
+                        0, 3, 1, 2);
+            assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch);
+
+            // // KQV shape [n_embd/n_head, N, n_head, n_batch]
+            // struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            // assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
+
+
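+            // note: ggml_flash_attn(Q, K, V, masked) computes the same scaled,
+            //       causally masked softmax attention as the commented-out
+            //       KQ/KQ_scaled/KQ_masked/KQ_soft_max/KQV steps above, fused
+            //       into a single op so the intermediate KQ tensors do not have
+            //       to be materialized in the graph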
+            bool masked = true;
+            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, masked);
+            assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
+            // KQV_merged shape
+
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+
+            // projection (no bias)
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].wo,
+                    cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // lctx.use_buf(ctx0, 1);
+
+        // inpFF shape [n_embd,N*n_batch,1,1]
+        struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA);
+        assert_shape_2d(inpFF, n_embd, N*n_batch);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                // cur shape [n_embd,N*n_batch,1,1]
+                cur = ggml_rms_norm(ctx0, inpFF);
+                assert_shape_2d(cur, n_embd, N*n_batch);
+
+                // cur = ffn_norm*cur
+                // cur shape [n_embd,N*n_batch,1,1]
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
+                        cur);
+                assert_shape_2d(cur, n_embd, N*n_batch);
+            }
+
+            // tmp shape [n_ff,N*n_batch,1,1]
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model->layers[il].w3,
+                    cur);
+            assert_shape_2d(tmp, n_ff, N*n_batch);
+
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w1,
+                    cur);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // SILU activation
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_silu(ctx0, cur);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_mul(ctx0, cur, tmp);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w2,
+                    cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // cur shape [n_embd,N*n_batch,1,1]
+        cur = ggml_add_inplace(ctx0, cur, inpFF);
+        assert_shape_2d(cur, n_embd, N*n_batch);
+
+        // input for next layer
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = cur;
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+    }
+
+    // norm
+    {
+
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = ggml_rms_norm(ctx0, inpL);
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+
+        // inpL = norm*inpL
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = ggml_mul(ctx0,
+                ggml_repeat(ctx0, model->norm, inpL),
+                inpL);
+
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+
+        //embeddings = inpL;
+    }
+
+    // lm_head
+    // inpL shape [n_vocab,N*n_batch,1,1]
+    inpL = ggml_mul_mat(ctx0, model->output, inpL);
+    assert_shape_2d(inpL, n_vocab, N*n_batch);
+
+    {
+        // inpL shape [n_vocab,N,n_batch,1]
+        inpL = ggml_reshape_3d(ctx0,
+                inpL,
+                n_vocab, N, n_batch);
+        assert_shape_3d(inpL, n_vocab, N, n_batch);
+    }
+
+    // run the computation
+    ggml_build_forward_expand(gf, inpL);
+
+    return inpL;
+}
+
 void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
     float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
     *ptr = value;
@@ -1644,7 +1877,7 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
     }
 
     std::string name = file->read_string(name_len);
-    GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)) == 0);
+    GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0);
 
     file->seek(-file->tell() & 31, SEEK_CUR);
     file->read_raw(tensor->data, ggml_nbytes(tensor));
@@ -1930,7 +2163,42 @@ int main(int argc, char ** argv) {
         //return 1;
     }
 
-    srand(time(NULL));
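+    // training/model hyperparameters, gathered here as hard-coded defaults;
+    // use_flash switches training to the fused forward_batch_wo_cache_flash_attn graph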
+    int seed = 1;
+    int n_ctx = 256;
+    // int n_ctx = 64;
+    int n_embd = 256;
+    int n_mult = 256;
+    int n_head = 8;
+    int n_layer = 16;
+    int n_rotmax = 64;
+
+    int n_threads = 6;
+    int n_batch = 8;
+    int n_examples = 32;
+
+    int print_info_interval = 1;
+    int print_details_interval = 2;
+
+    bool samples_start_after_nl = false;
+    bool use_adam = true;
+    bool use_flash = false;
+
+    // only adam
+    int warmup = 100;
+    int cos_decay_steps = 1000;
+    float cos_decay_restart = 1.1f;
+    float cos_decay_alpha = 0.0f;
+
+    int lbfgs_n_iter = 16;
+    int adam_n_iter = 16;
+    float adam_alpha = 1e-3;
+    float adam_decay = 1e-3;
+
+    if (seed < 0) {
+        srand(time(NULL));
+    } else {
+        srand(seed);
+    }
 
     const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
     const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
@@ -1971,12 +2239,12 @@ int main(int argc, char ** argv) {
 
     struct my_llama_model model;
     model.hparams.n_vocab = llama_n_vocab(lctx);
-    model.hparams.n_ctx = 32;
-    model.hparams.n_embd = 128;
-    model.hparams.n_mult = 64;
-    model.hparams.n_head = 16;
-    model.hparams.n_layer = 4;
-    model.hparams.n_rot = std::min(64u, model.hparams.n_embd / model.hparams.n_head);
+    model.hparams.n_ctx = n_ctx;
+    model.hparams.n_embd = n_embd;
+    model.hparams.n_mult = n_mult;
+    model.hparams.n_head = n_head;
+    model.hparams.n_layer = n_layer;
+    model.hparams.n_rot = std::min((uint32_t)n_rotmax, model.hparams.n_embd / model.hparams.n_head);
 
     print_params(&model.hparams);
 
@@ -2011,18 +2279,6 @@ int main(int argc, char ** argv) {
 
     my_llama_sampler sampler;
 
-    int n_threads = 6;
-    int n_batch = 32;
-    int n_examples = 32;
-
-    bool samples_start_after_nl = false;
-    bool use_adam = true;
-
-    int warmup = 100;
-    int cos_decay_steps = 1000;
-    float cos_decay_restart = 1.1f;
-    float cos_decay_alpha = 0.0f;
-
     int n_tokens = model.hparams.n_ctx;
     int n_vocab = model.hparams.n_vocab;
 
@@ -2035,15 +2291,15 @@ int main(int argc, char ** argv) {
     opt_params_adam.print_forward_graph = false;
     opt_params_adam.print_backward_graph = false;
     opt_params_adam.n_threads = n_threads;
-    opt_params_adam.adam.n_iter = 16;
+    opt_params_adam.adam.n_iter = adam_n_iter;
     opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = 1e-3;
-    opt_params_adam.adam.decay = 1e-3;
+    opt_params_adam.adam.alpha = adam_alpha;
+    opt_params_adam.adam.decay = adam_decay;
 
     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
-    opt_params_lbfgs.n_threads = n_threads;
-    opt_params_lbfgs.lbfgs.n_iter = 16;
+    opt_params_lbfgs.n_threads = n_threads;
+    opt_params_lbfgs.lbfgs.n_iter = lbfgs_n_iter;
 
     opt->ctx = model.ctx;
     opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
@@ -2117,7 +2373,9 @@ int main(int argc, char ** argv) {
 
         struct ggml_tensor * logits =
             (n_past == 0)
-            ? forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
+            ? (use_flash
+                ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
+                : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
             : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
 
         struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
@@ -2148,16 +2406,16 @@ int main(int argc, char ** argv) {
 
         float error_after_opt = ggml_get_f32_1d(e, 0);
 
-        printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
-        printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
 
-        if (ex % 1 == 0) {
+        if (ex % print_info_interval == 0) {
            printf("Example %d, opt iter %d\n", ex, opt->iter);
            printf("error_before_opt: %.6f\n", error_before_opt);
            printf("error_after_opt: %.6f\n", error_after_opt);
+            printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
+            printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
         }
 
-        if (ex % 2 == 0) {
+        if (ex % print_details_interval == 0) {
            // set_logits_masked(logits, token_notavail, -1e9);
            for (int i=0; i