diff --git a/llama.cpp b/llama.cpp index 772600b4c..70f763fd5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10,11 +10,11 @@ #include #include #include +#include // TODO: move somewhere else #define QK 32 - // determine number of model parts based on the dimension static const std::map LLAMA_N_PARTS = { { 4096, 1 }, @@ -23,6 +23,96 @@ static const std::map LLAMA_N_PARTS = { { 8192, 8 }, }; +struct llama_layer { + // normalization + struct ggml_tensor * attention_norm; + + // attention + struct ggml_tensor * wq; + struct ggml_tensor * wk; + struct ggml_tensor * wv; + struct ggml_tensor * wo; + + // normalization + struct ggml_tensor * ffn_norm; + + // ff + struct ggml_tensor * w1; + struct ggml_tensor * w2; + struct ggml_tensor * w3; +}; +struct llama_model { + llama_hparams hparams; + + struct ggml_tensor * tok_embeddings; + + struct ggml_tensor * norm; + struct ggml_tensor * output; + + std::vector layers; + + // key + value memory + struct ggml_tensor * memory_k; + struct ggml_tensor * memory_v; + + // + struct ggml_context * ctx; + std::map tensors; +}; +struct llama_state +{ + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + mutable std::mt19937 rng; + + std::vector embd{}; + + int input_consumed = 0; + std::vector embd_inp; + std::vector last_n_tokens; + int remaining_tokens = 0; + int n_past = 0; + size_t mem_per_token = 0; + bool is_initialized = false; + llama_state() {} + + bool has_more_input() const { + return input_consumed < embd_inp.size(); + } +}; +struct llama_context +{ + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16) + + llama_model model{}; + gpt_vocab vocab{}; + gpt_params params{}; + + std::unique_ptr state = nullptr; + + // default constructor + llama_context() = default; + // constructor + llama_context(llama_model&& model, gpt_vocab&& vocab, const gpt_params& params): + t_load_us(0), + t_start_us(0), + wtype(ggml_type::GGML_TYPE_F16), + model(std::move(model)), + vocab(std::move(vocab)), + params(params), + state(std::make_unique()) + { + } + ~llama_context(){ + ggml_free(model.ctx); + } +}; // load the model's weights from a file bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) { @@ -450,241 +540,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab return true; } -// evaluate the transformer -// -// - model: the model -// - n_threads: number of threads to use -// - n_past: the context size so far -// - embd_inp: the embeddings of the tokens in the context -// - embd_w: the predicted logits for the next token -// -// The GPT-J model requires about 16MB of memory per input token. -// -bool llama_eval( - const llama_model & model, - const int n_threads, - const int n_past, - const std::vector & embd_inp, - std::vector & embd_w, - size_t & mem_per_token) { - const int N = embd_inp.size(); - - const auto & hparams = model.hparams; - - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - const int n_head = hparams.n_head; - const int n_vocab = hparams.n_vocab; - const int n_rot = hparams.n_embd/hparams.n_head; - - const int d_key = n_embd/n_head; - - // TODO: check if this size scales with n_ctx linearly and remove constant. 
somehow I feel it wasn't the case - // static size_t buf_size = hparams.n_ctx*1024*1024; - static size_t buf_size = 512u*1024*1024; - static void * buf = malloc(buf_size); - - if (mem_per_token > 0 && mem_per_token*N > buf_size) { - const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead - //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); - - // reallocate - buf_size = buf_size_new; - buf = realloc(buf, buf_size); - if (buf == nullptr) { - fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); - return false; - } - } - - struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf, - }; - - struct ggml_context * ctx0 = ggml_init(params); - ggml_cgraph gf = {}; - gf.n_threads = n_threads; - - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); - - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // norm - { - cur = ggml_rms_norm(ctx0, inpL); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].attention_norm, cur), - cur); - } - - // self-attention - { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - - // store key and value to memory - if (N >= 1) { - struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past)); - - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); - } - - // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_rope(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)), - n_past, n_rot, 0), - 0, 2, 1, 3); - - // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_rope(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), - n_embd/n_head, n_head, n_past + N), - n_past, n_rot, 1), - 0, 2, 1, 3); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)) - ); - - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - - // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() - struct ggml_tensor * V_trans = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), - n_embd/n_head, n_head, n_past + N), - 1, 2, 0, 3); - - // KQV = transpose(V) * KQ_soft_max - struct ggml_tensor * KQV = 
ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - - // projection (no bias) - cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); - } - - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - - // feed-forward network - { - // norm - { - cur = ggml_rms_norm(ctx0, inpFF); - - // cur = ffn_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ffn_norm, cur), - cur); - } - - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - - // SILU activation - cur = ggml_silu(ctx0, cur); - - cur = ggml_mul(ctx0, cur, tmp); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - } - - cur = ggml_add(ctx0, cur, inpFF); - - // input for next layer - inpL = cur; - } - - // norm - { - inpL = ggml_rms_norm(ctx0, inpL); - - // inpL = norm*inpL - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model.norm, inpL), - inpL); - } - - // lm_head - { - inpL = ggml_mul_mat(ctx0, model.output, inpL); - } - - // logits -> probs - //inpL = ggml_soft_max(ctx0, inpL); - - // run the computation - ggml_build_forward_expand(&gf, inpL); - ggml_graph_compute (ctx0, &gf); - - //if (n_past%100 == 0) { - // ggml_graph_print (&gf); - // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); - //} - - //embd_w.resize(n_vocab*N); - //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); - - // return result for just the last token - embd_w.resize(n_vocab); - memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); - - if (mem_per_token == 0) { - mem_per_token = ggml_used_mem(ctx0)/N; - } - //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0)); - - ggml_free(ctx0); - - return true; -} bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) { ggml_type type = GGML_TYPE_Q4_1; @@ -940,3 +795,494 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna return true; } + + +/* External API */ + +const std::vector& llama_context_get_embd(const llama_context& ctx) { + return ctx.state->embd; +} +gpt_vocab& llama_context_get_vocab(llama_context& ctx) { + return ctx.vocab; +} +bool llama_context_not_finished(const llama_context& ctx) +{ + return ctx.state->remaining_tokens > 0; +} +const std::vector llama_tokenize_text(const llama_context& ctx, const std::string& text) { + return llama_tokenize(ctx.vocab, text, true); +} +const std::vector& llama_context_get_last_n_tokens(const llama_context& ctx) { + return ctx.state->last_n_tokens; +} +llama_context* llama_init_from_params(const gpt_params& params) { + llama_model model{}; + gpt_vocab vocab{}; + + // Compute time taken to load model + const int64_t t_start = ggml_time_us(); + bool ret = llama_model_load(params.model, model, vocab, 1024); + const int64_t t_end = ggml_time_us(); + if(!ret) + { + return nullptr; + } + llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params); + ctx->t_load_us = t_end - t_start; + return ctx; +} +void llama_free_context(llama_context* ctx) { + delete ctx; +} + +const char * llama_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | 
"; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +void llama_print_context_info(const llama_context& ctx) +{ + const gpt_params& params = ctx.params; + const std::vector& embd_inp = ctx.state->embd_inp; + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + } + fprintf(stderr, "\n"); + fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + for (int i = 0; i < (int) embd_inp.size(); i++) { + fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], ctx.vocab.id_to_token.at(embd_inp[i]).c_str()); + } + fprintf(stderr, "\n"); +} + +void llama_print_end_stats(const llama_context& ctx) +{ + const llama_state& state = *ctx.state; + fprintf(stderr, "\n\n"); + fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx.t_load_us/1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, state.t_sample_us/1000.0f); + fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, state.t_predict_us/1000.0f, state.t_predict_us/1000.0f/state.n_past); +} +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +// The GPT-J model requires about 16MB of memory per input token. +// +bool llama_eval( + const llama_model & model, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w, + size_t & mem_per_token) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_embd/hparams.n_head; + + const int d_key = n_embd/n_head; + + // TODO: check if this size scales with n_ctx linearly and remove constant. 
somehow I feel it wasn't the case + // static size_t buf_size = hparams.n_ctx*1024*1024; + static size_t buf_size = 512u*1024*1024; + static void * buf = malloc(buf_size); + + if (mem_per_token > 0 && mem_per_token*N > buf_size) { + const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead + //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); + + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + struct ggml_tensor * cur; + + // norm + { + cur = ggml_rms_norm(ctx0, inpL); + + // cur = attention_norm*cur + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].attention_norm, cur), + cur); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + + // store key and value to memory + if (N >= 1) { + struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_rope(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)), + n_past, n_rot, 0), + 0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_rope(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), + n_embd/n_head, n_head, n_past + N), + n_past, n_rot, 1), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)) + ); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ggml_tensor * V_trans = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + n_embd/n_head, n_head, n_past + N), + 1, 2, 0, 3); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = 
ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); + } + + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); + + // feed-forward network + { + // norm + { + cur = ggml_rms_norm(ctx0, inpFF); + + // cur = ffn_norm*cur + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ffn_norm, cur), + cur); + } + + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model.layers[il].w3, + cur); + + + cur = ggml_mul_mat(ctx0, + model.layers[il].w1, + cur); + + // SILU activation + cur = ggml_silu(ctx0, cur); + + cur = ggml_mul(ctx0, cur, tmp); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w2, + cur); + } + + cur = ggml_add(ctx0, cur, inpFF); + + // input for next layer + inpL = cur; + } + + // norm + { + inpL = ggml_rms_norm(ctx0, inpL); + + // inpL = norm*inpL + inpL = ggml_mul(ctx0, + ggml_repeat(ctx0, model.norm, inpL), + inpL); + } + + // lm_head + { + inpL = ggml_mul_mat(ctx0, model.output, inpL); + } + + // logits -> probs + //inpL = ggml_soft_max(ctx0, inpL); + + // run the computation + ggml_build_forward_expand(&gf, inpL); + ggml_graph_compute (ctx0, &gf); + + //if (n_past%100 == 0) { + // ggml_graph_print (&gf); + // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + //} + + //embd_w.resize(n_vocab*N); + //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0)/N; + } + //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0)); + + ggml_free(ctx0); + + return true; +} + +bool llama_init_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) { + llama_state& state = *ctx.state; + llama_model& model = ctx.model; + const gpt_params& params = ctx.params; + + if (clear_existing) { + state.embd.clear(); + state.input_consumed = 0; + state.embd_inp.clear(); + state.last_n_tokens.clear(); + state.remaining_tokens = 0; + state.n_past = 0; + } + + std::vector line_inp = llama_tokenize_text(ctx, text); + state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end()); + + int n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) state.embd_inp.size()); + state.remaining_tokens = n_predict; + + // determine the required inference memory per token: + state.mem_per_token = 0; + if(!llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, state.logits, state.mem_per_token)) + { + fprintf(stderr, "Failed to predict with initial prompt\n"); + return false; + } + + int last_n_size = params.repeat_last_n; + state.last_n_tokens = std::vector(last_n_size); + std::fill(state.last_n_tokens.begin(), state.last_n_tokens.end(), 0); + + state.is_initialized = true; + return true; +} + +/// @brief Injests a batch of input tokens into the context +/// @param ctx +void llama_injest_input_batch(llama_context& ctx) +{ + llama_state& state = *ctx.state; + const gpt_params& params = ctx.params; + + // Copy at most n_batch elements from embd_inp to embd + size_t num_copied = std::min((size_t) params.n_batch, state.embd_inp.size() - state.input_consumed); + std::copy(state.embd_inp.begin() 
+ state.input_consumed, + state.embd_inp.begin() + state.input_consumed + num_copied, + std::back_inserter(state.embd)); + state.input_consumed += num_copied; + + // Copy the last `repeat_last_n` elements copied into embd to last_n_tokens + size_t num_copied_last_n = std::min(num_copied, (size_t) params.repeat_last_n); + state.last_n_tokens.erase(state.last_n_tokens.begin(), state.last_n_tokens.begin()+num_copied_last_n); + state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end()); +} + +/// @brief Run the prediction step on ctx.embd and store result in ctx.state.logits +/// @param ctx +/// @return +bool llama_predict(llama_context& ctx){ + const int64_t t_start_us = ggml_time_us(); + llama_state& state = *ctx.state; + llama_model& model = ctx.model; + const gpt_params& params = ctx.params; + + if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) { + fprintf(stderr, "Failed to predict\n"); + return false; + } + + state.t_predict_us += ggml_time_us() - t_start_us; + return true; +} +/// @brief Sample a token from the logits +/// @param ctx +/// @return token id +gpt_vocab::id llama_sample_token(llama_context& ctx) +{ + llama_state& state = *ctx.state; + llama_model& model = ctx.model; + const gpt_params& params = ctx.params; + + const float top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + const float repeat_penalty = params.repeat_penalty; + + const int n_vocab = model.hparams.n_vocab; + + gpt_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_time_us(); + + id = llama_sample_top_p_top_k(ctx.vocab, state.logits.data() + (state.logits.size() - n_vocab), + state.last_n_tokens, repeat_penalty, top_k, top_p, temp, state.rng); + + state.last_n_tokens.erase(state.last_n_tokens.begin()); + state.last_n_tokens.push_back(id); + + state.t_sample_us += ggml_time_us() - t_start_sample_us; + } + return id; +} +/// @brief Injest all input (in multiple batches) into model and run call predict() +/// @param ctx +bool llama_injest_input(llama_context& ctx, const std::string& text, bool clear_existing) +{ + llama_state& state = *ctx.state; + + // Initialize context, tokenize text and clear existing state if necessary + if(!state.is_initialized && !llama_init_context_with_prompt(ctx, text, clear_existing)) + { + return false; + } + + // Injest the tokens into the model one batch at a time + while (state.has_more_input()) + { + llama_injest_input_batch(ctx); + if (state.embd.size() >= 0) { + if(!llama_predict(ctx)) + { + return false; + }; + } + state.n_past += state.embd.size(); + state.embd.clear(); + } + return true; +} +bool llama_inference(llama_context& ctx, gpt_vocab::id& id) { + llama_state& state = *ctx.state; + + // Tokenize text if we are starting out + if(!state.is_initialized) + { + fprintf(stderr, "State must be initialized before running inference"); + return false; + } + + // No more tokens to generate + if (state.remaining_tokens <= 0) { + return false; + } + + // Do prediction if we have enough tokens + if (state.embd.size() > 0) { + if(!llama_predict(ctx)) + { + return false; + } + } + // sample a token + id = llama_sample_token(ctx); + // add it to the context + state.embd.push_back(id); + + state.n_past += 1; + // decrement remaining sampling budget + --state.remaining_tokens; + + // end of text token + if (state.embd.back() == 2) { + state.remaining_tokens = 0; + } + return true; +} diff --git a/llama.h b/llama.h index 
a90227304..387efa686 100644
--- a/llama.h
+++ b/llama.h
@@ -3,8 +3,27 @@
 #include
 #include
 #include
+#include
 #include "utils.h"
+#include "ggml.h"
+
+#ifdef LLAMA_SHARED
+#    ifdef _WIN32
+#        ifdef LLAMA_BUILD
+#            define LLAMA_API __declspec(dllexport)
+#        else
+#            define LLAMA_API __declspec(dllimport)
+#        endif
+#    else
+#        define LLAMA_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define LLAMA_API
+#endif
+
+
 // default hparams (LLaMA 7B)
 struct llama_hparams {
@@ -18,50 +37,28 @@ struct llama_hparams {
     int32_t f16 = 1;
 };
 
-struct llama_layer {
-    // normalization
-    struct ggml_tensor * attention_norm;
+struct llama_context;
 
-    // attention
-    struct ggml_tensor * wq;
-    struct ggml_tensor * wk;
-    struct ggml_tensor * wv;
-    struct ggml_tensor * wo;
+void llama_free_context(llama_context* ctx);
 
-    // normalization
-    struct ggml_tensor * ffn_norm;
+const std::vector<gpt_vocab::id>& llama_context_get_embd(const llama_context& ctx);
+gpt_vocab& llama_context_get_vocab(llama_context& ctx);
+bool llama_context_not_finished(const llama_context& ctx);
+const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text);
 
-    // ff
-    struct ggml_tensor * w1;
-    struct ggml_tensor * w2;
-    struct ggml_tensor * w3;
-};
+const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx);
+bool llama_init_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing = true);
 
-struct llama_model {
-    llama_hparams hparams;
+// Various functions for loading a ggml LLaMA model.
+llama_context* llama_init_from_params(const gpt_params& params);
 
-    struct ggml_tensor * tok_embeddings;
+// Run inference on a LLaMA model using llama_context.
+std::vector llama_eval(llama_context& ctx, const gpt_params& params, std::string& text);
 
-    struct ggml_tensor * norm;
-    struct ggml_tensor * output;
-
-    std::vector<llama_layer> layers;
-
-    // key + value memory
-    struct ggml_tensor * memory_k;
-    struct ggml_tensor * memory_v;
-
-    //
-    struct ggml_context * ctx;
-    std::map<std::string, struct ggml_tensor *> tensors;
-};
-
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx);
-bool llama_eval(
-        const llama_model & model,
-        const int n_threads,
-        const int n_past,
-        const std::vector<gpt_vocab::id> & embd_inp,
-              std::vector<float>         & embd_w,
-              size_t                     & mem_per_token);
 bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype);
+
+bool llama_injest_input(llama_context& ctx, const std::string& text, bool clear_existing = true);
+
+bool llama_inference(llama_context& ctx, gpt_vocab::id& model_output);
+void llama_print_context_info(const llama_context& ctx);
+void llama_print_end_stats(const llama_context& ctx);
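
Note: the header adds the LLAMA_API export macro but does not yet apply it to the declarations below it. A minimal sketch of how the macro would typically annotate the exported entry points once the library is built as a shared object (illustrative only, not part of this patch):

LLAMA_API llama_context * llama_init_from_params(const gpt_params & params);
LLAMA_API bool llama_injest_input(llama_context & ctx, const std::string & text, bool clear_existing = true);
LLAMA_API bool llama_inference(llama_context & ctx, gpt_vocab::id & model_output);
LLAMA_API void llama_free_context(llama_context * ctx);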
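
Usage sketch (not part of the patch): a minimal driver showing how the llama_context API introduced by this diff fits together. It assumes the gpt_params struct and a gpt_params_parse helper from utils.h; the real example program may differ in details such as prompt handling and output formatting.

#include <cstdio>

#include "llama.h"
#include "utils.h"

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) { // command-line helper assumed to come from utils.h
        return 1;
    }

    // load the model weights and build a context (load time is recorded inside the context)
    llama_context * ctx = llama_init_from_params(params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model '%s'\n", params.model.c_str());
        return 1;
    }

    // tokenize the prompt and feed it through the model in batches of params.n_batch tokens
    if (!llama_injest_input(*ctx, params.prompt)) {
        llama_free_context(ctx);
        return 1;
    }
    llama_print_context_info(*ctx);

    // sample one token at a time until the budget derived from params.n_predict runs out
    // or the end-of-text token is produced (llama_inference then reports no more tokens)
    gpt_vocab & vocab = llama_context_get_vocab(*ctx);
    while (llama_context_not_finished(*ctx)) {
        gpt_vocab::id id = 0;
        if (!llama_inference(*ctx, id)) {
            break;
        }
        printf("%s", vocab.id_to_token.at(id).c_str());
        fflush(stdout);
    }

    llama_print_end_stats(*ctx);
    llama_free_context(ctx);
    return 0;
}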