Refactor code structure in llama.cpp and llama.h

Thomas Antony 2023-03-12 18:48:48 -07:00
parent 0995df5c9e
commit b14486e1c0
2 changed files with 619 additions and 276 deletions

llama.cpp (818 lines changed)

@@ -10,11 +10,11 @@
#include <string>
#include <vector>
#include <regex>
#include <optional>
// TODO: move somewhere else
#define QK 32
// determine number of model parts based on the dimension
static const std::map<int, int> LLAMA_N_PARTS = {
{ 4096, 1 },
@@ -23,6 +23,96 @@ static const std::map<int, int> LLAMA_N_PARTS = {
{ 8192, 8 },
};
struct llama_layer {
// normalization
struct ggml_tensor * attention_norm;
// attention
struct ggml_tensor * wq;
struct ggml_tensor * wk;
struct ggml_tensor * wv;
struct ggml_tensor * wo;
// normalization
struct ggml_tensor * ffn_norm;
// ff
struct ggml_tensor * w1;
struct ggml_tensor * w2;
struct ggml_tensor * w3;
};
struct llama_model {
llama_hparams hparams;
struct ggml_tensor * tok_embeddings;
struct ggml_tensor * norm;
struct ggml_tensor * output;
std::vector<llama_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};
struct llama_state
{
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
std::vector<float> logits;
mutable std::mt19937 rng;
std::vector<gpt_vocab::id> embd{};
int input_consumed = 0;
std::vector<gpt_vocab::id> embd_inp;
std::vector<gpt_vocab::id> last_n_tokens;
int remaining_tokens = 0;
int n_past = 0;
size_t mem_per_token = 0;
bool is_initialized = false;
llama_state() {}
bool has_more_input() const {
return input_consumed < embd_inp.size();
}
};
struct llama_context
{
int64_t t_load_us = 0;
int64_t t_start_us = 0;
ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
llama_model model{};
gpt_vocab vocab{};
gpt_params params{};
std::unique_ptr<llama_state> state = nullptr;
// default constructor
llama_context() = default;
// constructor
llama_context(llama_model&& model, gpt_vocab&& vocab, const gpt_params& params):
t_load_us(0),
t_start_us(0),
wtype(ggml_type::GGML_TYPE_F16),
model(std::move(model)),
vocab(std::move(vocab)),
params(params),
state(std::make_unique<llama_state>())
{
}
~llama_context(){
ggml_free(model.ctx);
}
};
// load the model's weights from a file
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
@@ -450,241 +540,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
// The model requires about 16MB of memory per input token.
//
bool llama_eval(
const llama_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_embd/hparams.n_head;
const int d_key = n_embd/n_head;
// TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
// static size_t buf_size = hparams.n_ctx*1024*1024;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
struct ggml_tensor * cur;
// norm
{
cur = ggml_rms_norm(ctx0, inpL);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
n_past, n_rot, 0),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
n_past, n_rot, 1),
0, 2, 1, 3);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);
// KQV = transpose(V) * KQ_soft_max
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
// projection (no bias)
cur = ggml_mul_mat(ctx0,
model.layers[il].wo,
cur);
}
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
// feed-forward network
{
// norm
{
cur = ggml_rms_norm(ctx0, inpFF);
// cur = ffn_norm*cur
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
cur);
}
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
model.layers[il].w3,
cur);
cur = ggml_mul_mat(ctx0,
model.layers[il].w1,
cur);
// SILU activation
cur = ggml_silu(ctx0, cur);
cur = ggml_mul(ctx0, cur, tmp);
cur = ggml_mul_mat(ctx0,
model.layers[il].w2,
cur);
}
cur = ggml_add(ctx0, cur, inpFF);
// input for next layer
inpL = cur;
}
// norm
{
inpL = ggml_rms_norm(ctx0, inpL);
// inpL = norm*inpL
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL),
inpL);
}
// lm_head
{
inpL = ggml_mul_mat(ctx0, model.output, inpL);
}
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
ggml_type type = GGML_TYPE_Q4_1;
@@ -940,3 +795,494 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
return true;
}
/* External API */
const std::vector<gpt_vocab::id>& llama_context_get_embd(const llama_context& ctx) {
return ctx.state->embd;
}
gpt_vocab& llama_context_get_vocab(llama_context& ctx) {
return ctx.vocab;
}
bool llama_context_not_finished(const llama_context& ctx)
{
return ctx.state->remaining_tokens > 0;
}
const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
return llama_tokenize(ctx.vocab, text, true);
}
const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
return ctx.state->last_n_tokens;
}
llama_context* llama_init_from_params(const gpt_params& params) {
llama_model model{};
gpt_vocab vocab{};
// Compute time taken to load model
const int64_t t_start = ggml_time_us();
bool ret = llama_model_load(params.model, model, vocab, 1024);
const int64_t t_end = ggml_time_us();
if(!ret)
{
return nullptr;
}
llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params);
ctx->t_load_us = t_end - t_start;
return ctx;
}
void llama_free_context(llama_context* ctx) {
delete ctx;
}
const char * llama_print_system_info(void) {
static std::string s;
s = "";
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
return s.c_str();
}
void llama_print_context_info(const llama_context& ctx)
{
const gpt_params& params = ctx.params;
const std::vector<gpt_vocab::id>& embd_inp = ctx.state->embd_inp;
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], ctx.vocab.id_to_token.at(embd_inp[i]).c_str());
}
fprintf(stderr, "\n");
}
void llama_print_end_stats(const llama_context& ctx)
{
const llama_state& state = *ctx.state;
fprintf(stderr, "\n\n");
fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token);
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx.t_load_us/1000.0f);
fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, state.t_sample_us/1000.0f);
fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, state.t_predict_us/1000.0f, state.t_predict_us/1000.0f/state.n_past);
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
// The model requires about 16MB of memory per input token.
//
bool llama_eval(
const llama_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_embd/hparams.n_head;
const int d_key = n_embd/n_head;
// TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
// static size_t buf_size = hparams.n_ctx*1024*1024;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
struct ggml_tensor * cur;
// norm
{
cur = ggml_rms_norm(ctx0, inpL);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
n_past, n_rot, 0),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
n_past, n_rot, 1),
0, 2, 1, 3);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);
// KQV = transpose(V) * KQ_soft_max
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
// projection (no bias)
cur = ggml_mul_mat(ctx0,
model.layers[il].wo,
cur);
}
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
// feed-forward network
{
// norm
{
cur = ggml_rms_norm(ctx0, inpFF);
// cur = ffn_norm*cur
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
cur);
}
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
model.layers[il].w3,
cur);
cur = ggml_mul_mat(ctx0,
model.layers[il].w1,
cur);
// SILU activation
cur = ggml_silu(ctx0, cur);
cur = ggml_mul(ctx0, cur, tmp);
cur = ggml_mul_mat(ctx0,
model.layers[il].w2,
cur);
}
cur = ggml_add(ctx0, cur, inpFF);
// input for next layer
inpL = cur;
}
// norm
{
inpL = ggml_rms_norm(ctx0, inpL);
// inpL = norm*inpL
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL),
inpL);
}
// lm_head
{
inpL = ggml_mul_mat(ctx0, model.output, inpL);
}
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
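// Illustrative sketch added by the editor (not part of this commit): driving the
// low-level llama_eval() above directly. The model path, prompt, and thread count
// are assumptions; the dummy warm-up batch mirrors the call used by
// llama_init_context_with_prompt() below to size mem_per_token.
static bool example_eval_directly() {
    llama_model model;
    gpt_vocab   vocab;
    if (!llama_model_load("models/7B/ggml-model-f16.bin", model, vocab, 1024)) {
        return false;
    }

    std::vector<float> logits;   // receives n_vocab logits for the last token
    size_t mem_per_token = 0;

    // warm-up on a dummy batch so the scratch buffer can be sized
    if (!llama_eval(model, /*n_threads=*/4, /*n_past=*/0, { 0, 1, 2, 3 }, logits, mem_per_token)) {
        return false;
    }

    // evaluate the real prompt tokens starting from an empty context
    const std::vector<gpt_vocab::id> tokens = llama_tokenize(vocab, "Hello world", true);
    if (!llama_eval(model, 4, 0, tokens, logits, mem_per_token)) {
        return false;
    }

    ggml_free(model.ctx);        // the caller owns the ggml context in this usage
    return true;
}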
bool llama_init_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) {
llama_state& state = *ctx.state;
llama_model& model = ctx.model;
const gpt_params& params = ctx.params;
if (clear_existing) {
state.embd.clear();
state.input_consumed = 0;
state.embd_inp.clear();
state.last_n_tokens.clear();
state.remaining_tokens = 0;
state.n_past = 0;
}
std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());
int n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) state.embd_inp.size());
state.remaining_tokens = n_predict;
// determine the required inference memory per token:
state.mem_per_token = 0;
if(!llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, state.logits, state.mem_per_token))
{
fprintf(stderr, "Failed to predict with initial prompt\n");
return false;
}
int last_n_size = params.repeat_last_n;
state.last_n_tokens = std::vector<gpt_vocab::id>(last_n_size);
std::fill(state.last_n_tokens.begin(), state.last_n_tokens.end(), 0);
state.is_initialized = true;
return true;
}
/// @brief Ingests a batch of input tokens into the context
/// @param ctx
void llama_injest_input_batch(llama_context& ctx)
{
llama_state& state = *ctx.state;
const gpt_params& params = ctx.params;
// Copy at most n_batch elements from embd_inp to embd
size_t num_copied = std::min((size_t) params.n_batch, state.embd_inp.size() - state.input_consumed);
std::copy(state.embd_inp.begin() + state.input_consumed,
state.embd_inp.begin() + state.input_consumed + num_copied,
std::back_inserter(state.embd));
state.input_consumed += num_copied;
// Copy the last `repeat_last_n` elements copied into embd to last_n_tokens
size_t num_copied_last_n = std::min(num_copied, (size_t) params.repeat_last_n);
state.last_n_tokens.erase(state.last_n_tokens.begin(), state.last_n_tokens.begin()+num_copied_last_n);
state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
}
/// @brief Run the prediction step on ctx.embd and store the result in ctx.state.logits
/// @param ctx
/// @return true on success
bool llama_predict(llama_context& ctx){
const int64_t t_start_us = ggml_time_us();
llama_state& state = *ctx.state;
llama_model& model = ctx.model;
const gpt_params& params = ctx.params;
if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
fprintf(stderr, "Failed to predict\n");
return false;
}
state.t_predict_us += ggml_time_us() - t_start_us;
return true;
}
/// @brief Sample a token from the logits
/// @param ctx
/// @return token id
gpt_vocab::id llama_sample_token(llama_context& ctx)
{
llama_state& state = *ctx.state;
llama_model& model = ctx.model;
const gpt_params& params = ctx.params;
const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
const int n_vocab = model.hparams.n_vocab;
gpt_vocab::id id = 0;
{
const int64_t t_start_sample_us = ggml_time_us();
id = llama_sample_top_p_top_k(ctx.vocab, state.logits.data() + (state.logits.size() - n_vocab),
state.last_n_tokens, repeat_penalty, top_k, top_p, temp, state.rng);
state.last_n_tokens.erase(state.last_n_tokens.begin());
state.last_n_tokens.push_back(id);
state.t_sample_us += ggml_time_us() - t_start_sample_us;
}
return id;
}
/// @brief Ingest all input (in multiple batches) into the model, calling predict() on each batch
/// @param ctx
bool llama_injest_input(llama_context& ctx, const std::string& text, bool clear_existing)
{
llama_state& state = *ctx.state;
// Initialize context, tokenize text and clear existing state if necessary
if(!state.is_initialized && !llama_init_context_with_prompt(ctx, text, clear_existing))
{
return false;
}
// Ingest the tokens into the model one batch at a time
while (state.has_more_input())
{
llama_injest_input_batch(ctx);
if (state.embd.size() > 0) {
if(!llama_predict(ctx))
{
return false;
}
}
state.n_past += state.embd.size();
state.embd.clear();
}
return true;
}
bool llama_inference(llama_context& ctx, gpt_vocab::id& id) {
llama_state& state = *ctx.state;
// Tokenize text if we are starting out
if(!state.is_initialized)
{
fprintf(stderr, "State must be initialized before running inference");
return false;
}
// No more tokens to generate
if (state.remaining_tokens <= 0) {
return false;
}
// Do prediction if we have enough tokens
if (state.embd.size() > 0) {
if(!llama_predict(ctx))
{
return false;
}
}
// sample a token
id = llama_sample_token(ctx);
// add it to the context
state.embd.push_back(id);
state.n_past += 1;
// decrement remaining sampling budget
--state.remaining_tokens;
// end of text token
if (state.embd.back() == 2) {
state.remaining_tokens = 0;
}
return true;
}
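// Illustrative sketch added by the editor (not part of this commit): end-to-end
// use of the external API above. Assumes params.model already points at a ggml
// model file; the prompt is a placeholder and error handling is minimal.
static int example_generate(gpt_params params) {
    params.prompt = "The quick brown fox";

    llama_context * ctx = llama_init_from_params(params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to initialize llama_context\n");
        return 1;
    }

    // tokenize and feed the prompt, one batch at a time
    if (!llama_injest_input(*ctx, params.prompt, /*clear_existing=*/true)) {
        llama_free_context(ctx);
        return 1;
    }
    llama_print_context_info(*ctx);

    // sample tokens until the budget runs out or the end-of-text token appears
    gpt_vocab & vocab = llama_context_get_vocab(*ctx);
    gpt_vocab::id id  = 0;
    while (llama_context_not_finished(*ctx) && llama_inference(*ctx, id)) {
        printf("%s", vocab.id_to_token.at(id).c_str());
        fflush(stdout);
    }

    llama_print_end_stats(*ctx);
    llama_free_context(ctx);
    return 0;
}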

llama.h (77 lines changed)

@@ -3,8 +3,27 @@
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "utils.h"
#include "ggml.h"
#ifdef LLAMA_SHARED
# ifdef _WIN32
# ifdef LLAMA_BUILD
# define LLAMA_API __declspec(dllexport)
# else
# define LLAMA_API __declspec(dllimport)
# endif
# else
# define LLAMA_API __attribute__ ((visibility ("default")))
# endif
#else
# define LLAMA_API
#endif
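// Editor's note (not part of this commit): the declarations shown in this diff are
// not yet annotated with LLAMA_API. Once they are, an exported function would look
// like the hypothetical line below; building the library with -DLLAMA_SHARED
// -DLLAMA_BUILD selects dllexport on Windows, consumers defining only LLAMA_SHARED
// get dllimport, and other platforms fall back to default symbol visibility.
//
//     LLAMA_API llama_context* llama_init_from_params(const gpt_params& params);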
// default hparams (LLaMA 7B)
struct llama_hparams {
@@ -18,50 +37,28 @@ struct llama_hparams {
int32_t f16 = 1;
};
struct llama_layer {
// normalization
struct ggml_tensor * attention_norm;
// attention
struct ggml_tensor * wq;
struct ggml_tensor * wk;
struct ggml_tensor * wv;
struct ggml_tensor * wo;
// normalization
struct ggml_tensor * ffn_norm;
// ff
struct ggml_tensor * w1;
struct ggml_tensor * w2;
struct ggml_tensor * w3;
};

struct llama_model {
llama_hparams hparams;
struct ggml_tensor * tok_embeddings;
struct ggml_tensor * norm;
struct ggml_tensor * output;
std::vector<llama_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};

struct llama_context;

void llama_free_context(llama_context* ctx);

const std::vector<gpt_vocab::id>& llama_context_get_embd(const llama_context& ctx);
gpt_vocab& llama_context_get_vocab(llama_context& ctx);
bool llama_context_not_finished(const llama_context& ctx);
const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text);
const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx);
bool llama_init_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing = true);

// Various functions for loading a ggml LLaMA model.
llama_context* llama_init_from_params(const gpt_params& params);

// Run inference on a LLaMA model using llama_context.
std::vector<float> llama_eval(llama_context& ctx, const gpt_params& params, std::string& text);
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx);
bool llama_eval(
const llama_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token);
bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype);
bool llama_injest_input(llama_context& ctx, const std::string& text, bool clear_existing = true);
bool llama_inference(llama_context& ctx, gpt_vocab::id& model_output);
void llama_print_context_info(const llama_context& ctx);
void llama_print_end_stats(const llama_context& ctx);
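// Illustrative sketch added by the editor (not part of this commit): the header also
// keeps the standalone quantization entry point. A converter tool could call it
// roughly as below. File names are placeholders, and the meaning of `itype` is
// defined by the implementation in llama.cpp (assumed here: 2 selects Q4_0,
// 3 selects Q4_1).
static bool example_quantize() {
    const std::string fname_inp = "models/7B/ggml-model-f16.bin";
    const std::string fname_out = "models/7B/ggml-model-q4_0.bin";
    const int itype = 2; // assumed mapping; see llama_model_quantize() in llama.cpp
    return llama_model_quantize(fname_inp, fname_out, itype);
}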