From 2852902eda37e8490a07fcd0ab0d803e59260a52 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Wed, 24 Jan 2024 14:55:41 -0500
Subject: [PATCH] test-backend-ops : add llama test

---
 tests/test-backend-ops.cpp | 212 +++++++++++++++++++++++++++++++++++++
 1 file changed, 212 insertions(+)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a0063bbb9..b776e493a 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1464,6 +1464,216 @@ struct test_moe : public test_case {
     }
 };
 
+
+// llama
+struct test_llama : public test_case {
+    const int n_tokens;
+    static constexpr float f_norm_rms_eps = 1e-5;
+    static constexpr int64_t n_embd_k_gqa = 3200;
+    static constexpr int64_t n_embd_v_gqa = 3200;
+    static constexpr int64_t n_ctx = 512;
+    static constexpr int64_t n_layer = 1;
+    static constexpr int64_t n_head = 32;
+    static constexpr int64_t n_head_kv = 32;
+    static constexpr int64_t n_embd_head = 100;
+    static constexpr int64_t n_embd = 3200;
+    static constexpr int64_t n_orig_ctx = n_ctx;
+    static constexpr int64_t n_ff = 8640;
+    static constexpr int64_t n_kv = 32;
+    static constexpr int64_t kv_head = 1;
+    static constexpr float freq_base = 10000.0f;
+    static constexpr float freq_scale = 1.0f;
+    static constexpr float ext_factor = 0.0f;
+    static constexpr float attn_factor = 1.0f;
+    static constexpr float beta_fast = 32.0f;
+    static constexpr float beta_slow = 1.0f;
+
+    std::string op_desc(ggml_tensor * t) override {
+        return "LLAMA";
+
+        GGML_UNUSED(t);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR1(n_tokens);
+    }
+
+    test_llama(int n_tokens = 1)
+        : n_tokens(n_tokens) {
+    }
+
+    struct ggml_tensor * llm_build_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor * cur,
+            struct ggml_tensor * mw) {
+        cur = ggml_rms_norm(ctx, cur, f_norm_rms_eps);
+        cur = ggml_mul(ctx, cur, mw);
+        return cur;
+    }
+
+    void llm_build_kv_store(
+            struct ggml_context * ctx,
+            struct ggml_tensor * k_l,
+            struct ggml_tensor * v_l,
+            struct ggml_tensor * k_cur,
+            struct ggml_tensor * v_cur) {
+        // compute the transposed [n_tokens, n_embd] V matrix
+        struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
+
+        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, n_tokens*n_embd_k_gqa,
+                (ggml_row_size(k_l->type, n_embd_k_gqa))*kv_head);
+
+        struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, n_tokens, n_embd_v_gqa,
+                (  n_ctx)*ggml_element_size(v_l),
+                (kv_head)*ggml_element_size(v_l));
+
+        // important: storing RoPE-ed version of K in the KV cache!
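+        // note: K is copied row-contiguous ([n_embd_k_gqa, n_tokens] starting at
+        // row kv_head), while V is written transposed so that each embedding
+        // dimension fills a contiguous row of n_ctx slots, matching the views
+        // taken in llm_build_kqv below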
+        ggml_cpy(ctx, k_cur, k_cache_view);
+        ggml_cpy(ctx, v_cur_t, v_cache_view);
+    }
+
+    struct ggml_tensor * llm_build_kqv(
+            struct ggml_context * ctx,
+            struct ggml_tensor * k_l,
+            struct ggml_tensor * v_l,
+            struct ggml_tensor * q_cur,
+            struct ggml_tensor * kq_mask,
+            float kq_scale) {
+        struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
+
+        struct ggml_tensor * k =
+            ggml_view_3d(ctx, k_l,
+                    n_embd_head, n_kv, n_head_kv,
+                    ggml_row_size(k_l->type, n_embd_k_gqa),
+                    ggml_row_size(k_l->type, n_embd_head),
+                    0);
+
+        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
+
+        // split cached v into n_head heads
+        struct ggml_tensor * v =
+            ggml_view_3d(ctx, v_l,
+                    n_kv, n_embd_head, n_head_kv,
+                    ggml_element_size(v_l)*n_ctx,
+                    ggml_element_size(v_l)*n_ctx*n_embd_head,
+                    0);
+
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+
+        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+
+        struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head*n_head, n_tokens);
+
+        struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
+        cur = ggml_mul_mat(ctx, wo, cur);
+
+        return cur;
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_rot = n_embd_head;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1);
+
+        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200);
+            cur = llm_build_norm(ctx, inpL, attn_norm);
+
+            // self-attention
+            {
+                ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
+                ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
+                ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
+
+                Qcur = ggml_rope_custom(
+                    ctx, ggml_reshape_3d(ctx, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                Kcur = ggml_rope_custom(
+                    ctx, ggml_reshape_3d(ctx, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
+
+                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(n_embd_head)));
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
+
+            // feed-forward network
+            ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200);
+            cur = llm_build_norm(ctx, ffn_inp, ffn_norm);
+
+            ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 8640);
+            ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 8640);
+            ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 8640, 3200);
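+            // SwiGLU feed-forward: down(silu(gate(x)) * up(x)); ggml_mul_mat takes
+            // the weight as its first argument, so the projections run 3200 -> 8640 -> 3200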
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
+            cur = ggml_mul_mat(ctx, ffn_gate, cur);
+            cur = ggml_silu(ctx, cur);
+            cur = ggml_mul(ctx, cur, tmp);
+            cur = ggml_mul_mat(ctx, ffn_down, cur);
+
+            cur = ggml_add(ctx, cur, ffn_inp);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200);
+        cur = llm_build_norm(ctx, cur, output_norm);
+
+        // lm_head
+        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 32000);
+        cur = ggml_mul_mat(ctx, output, cur);
+
+        return cur;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                std::vector<int> data(n_tokens);
+                for (int i = 0; i < n_tokens; i++) {
+                    data[i] = rand() % n_ctx;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, n_tokens * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
     std::default_random_engine rng(0);
@@ -1651,6 +1861,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
 #endif
 
+    test_cases.emplace_back(new test_llama());
+
     // run tests
     if (mode == MODE_TEST) {
         ggml_backend_t backend_cpu = ggml_backend_cpu_init();