Refactor code structure in llama.cpp and llama.h

parent 0995df5c9e
commit b14486e1c0

2 changed files with 619 additions and 276 deletions

llama.cpp (818 changes)

@@ -10,11 +10,11 @@
#include <string>
#include <vector>
#include <regex>
#include <optional>

// TODO: move somewhere else
#define QK 32

// determine number of model parts based on the dimension
static const std::map<int, int> LLAMA_N_PARTS = {
    { 4096, 1 },
@@ -23,6 +23,96 @@ static const std::map<int, int> LLAMA_N_PARTS = {
    { 8192, 8 },
};

struct llama_layer {
    // normalization
    struct ggml_tensor * attention_norm;

    // attention
    struct ggml_tensor * wq;
    struct ggml_tensor * wk;
    struct ggml_tensor * wv;
    struct ggml_tensor * wo;

    // normalization
    struct ggml_tensor * ffn_norm;

    // ff
    struct ggml_tensor * w1;
    struct ggml_tensor * w2;
    struct ggml_tensor * w3;
};

struct llama_model {
    llama_hparams hparams;

    struct ggml_tensor * tok_embeddings;

    struct ggml_tensor * norm;
    struct ggml_tensor * output;

    std::vector<llama_layer> layers;

    // key + value memory
    struct ggml_tensor * memory_k;
    struct ggml_tensor * memory_v;

    //
    struct ggml_context * ctx;
    std::map<std::string, struct ggml_tensor *> tensors;
};

struct llama_state
{
    int64_t t_sample_us = 0;
    int64_t t_predict_us = 0;

    std::vector<float> logits;

    mutable std::mt19937 rng;

    std::vector<gpt_vocab::id> embd{};

    int input_consumed = 0;
    std::vector<gpt_vocab::id> embd_inp;
    std::vector<gpt_vocab::id> last_n_tokens;
    int remaining_tokens = 0;
    int n_past = 0;
    size_t mem_per_token = 0;
    bool is_initialized = false;

    llama_state() {}

    bool has_more_input() const {
        return input_consumed < embd_inp.size();
    }
};

struct llama_context
{
    int64_t t_load_us = 0;
    int64_t t_start_us = 0;

    ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)

    llama_model model{};
    gpt_vocab  vocab{};
    gpt_params params{};

    std::unique_ptr<llama_state> state = nullptr;

    // default constructor
    llama_context() = default;

    // constructor
    llama_context(llama_model&& model, gpt_vocab&& vocab, const gpt_params& params):
        t_load_us(0),
        t_start_us(0),
        wtype(ggml_type::GGML_TYPE_F16),
        model(std::move(model)),
        vocab(std::move(vocab)),
        params(params),
        state(std::make_unique<llama_state>())
    {
    }

    ~llama_context() {
        ggml_free(model.ctx);
    }
};
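
For orientation, a minimal sketch of how the new llama_context is meant to be constructed; it mirrors llama_init_from_params added further down in this diff, and the local variable names and the n_ctx value here are illustrative, not part of the commit:

    // Load a model, then hand ownership to a llama_context via the move
    // constructor above; its destructor later calls ggml_free(model.ctx).
    llama_model model{};
    gpt_vocab  vocab{};
    gpt_params params{};  // assumed to carry the model path in params.model

    if (llama_model_load(params.model, model, vocab, 1024)) {
        auto ctx = std::make_unique<llama_context>(std::move(model), std::move(vocab), params);
        // ctx->state was allocated by the constructor (std::make_unique<llama_state>())
    }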

// load the model's weights from a file
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
@@ -450,241 +540,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
    return true;
}

// evaluate the transformer
//
// - model:     the model
// - n_threads: number of threads to use
// - n_past:    the context size so far
// - embd_inp:  the embeddings of the tokens in the context
// - embd_w:    the predicted logits for the next token
//
// The GPT-J model requires about 16MB of memory per input token.
//
bool llama_eval(
        const llama_model & model,
        const int n_threads,
        const int n_past,
        const std::vector<gpt_vocab::id> & embd_inp,
        std::vector<float> & embd_w,
        size_t & mem_per_token) {
    const int N = embd_inp.size();

    const auto & hparams = model.hparams;

    const int n_embd  = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_ctx   = hparams.n_ctx;
    const int n_head  = hparams.n_head;
    const int n_vocab = hparams.n_vocab;
    const int n_rot   = hparams.n_embd/hparams.n_head;

    const int d_key = n_embd/n_head;

    // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
    // static size_t buf_size = hparams.n_ctx*1024*1024;
    static size_t buf_size = 512u*1024*1024;
    static void * buf = malloc(buf_size);

    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

        // reallocate
        buf_size = buf_size_new;
        buf = realloc(buf, buf_size);
        if (buf == nullptr) {
            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
            return false;
        }
    }

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    ggml_cgraph gf = {};
    gf.n_threads = n_threads;

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));

    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);

    for (int il = 0; il < n_layer; ++il) {
        struct ggml_tensor * inpSA = inpL;

        struct ggml_tensor * cur;

        // norm
        {
            cur = ggml_rms_norm(ctx0, inpL);

            // cur = attention_norm*cur
            cur = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
                    cur);
        }

        // self-attention
        {
            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);

            // store key and value to memory
            if (N >= 1) {
                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));

                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
            }

            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_rope(ctx0,
                            ggml_cpy(ctx0,
                                Qcur,
                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
                            n_past, n_rot, 0),
                        0, 2, 1, 3);

            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_rope(ctx0,
                            ggml_reshape_3d(ctx0,
                                ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
                                n_embd/n_head, n_head, n_past + N),
                            n_past, n_rot, 1),
                        0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            struct ggml_tensor * KQ_scaled =
                ggml_scale(ctx0,
                        KQ,
                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));

            // KQ_masked = mask_past(KQ_scaled)
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

            // KQ = soft_max(KQ_masked)
            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
            struct ggml_tensor * V_trans =
                ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
                            n_embd/n_head, n_head, n_past + N),
                        1, 2, 0, 3);

            // KQV = transpose(V) * KQ_soft_max
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            // KQV_merged = KQV.permute(0, 2, 1, 3)
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_embd, N)
            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

            // projection (no bias)
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].wo,
                    cur);
        }

        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);

        // feed-forward network
        {
            // norm
            {
                cur = ggml_rms_norm(ctx0, inpFF);

                // cur = ffn_norm*cur
                cur = ggml_mul(ctx0,
                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
                        cur);
            }

            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                    model.layers[il].w3,
                    cur);

            cur = ggml_mul_mat(ctx0,
                    model.layers[il].w1,
                    cur);

            // SILU activation
            cur = ggml_silu(ctx0, cur);

            cur = ggml_mul(ctx0, cur, tmp);

            cur = ggml_mul_mat(ctx0,
                    model.layers[il].w2,
                    cur);
        }

        cur = ggml_add(ctx0, cur, inpFF);

        // input for next layer
        inpL = cur;
    }

    // norm
    {
        inpL = ggml_rms_norm(ctx0, inpL);

        // inpL = norm*inpL
        inpL = ggml_mul(ctx0,
                ggml_repeat(ctx0, model.norm, inpL),
                inpL);
    }

    // lm_head
    {
        inpL = ggml_mul_mat(ctx0, model.output, inpL);
    }

    // logits -> probs
    //inpL = ggml_soft_max(ctx0, inpL);

    // run the computation
    ggml_build_forward_expand(&gf, inpL);
    ggml_graph_compute (ctx0, &gf);

    //if (n_past%100 == 0) {
    //    ggml_graph_print (&gf);
    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
    //}

    //embd_w.resize(n_vocab*N);
    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

    // return result for just the last token
    embd_w.resize(n_vocab);
    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }
    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));

    ggml_free(ctx0);

    return true;
}

bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
    ggml_type type = GGML_TYPE_Q4_1;
@@ -940,3 +795,494 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
    return true;
}

/* External API */

const std::vector<gpt_vocab::id>& llama_context_get_embd(const llama_context& ctx) {
    return ctx.state->embd;
}

gpt_vocab& llama_context_get_vocab(llama_context& ctx) {
    return ctx.vocab;
}

bool llama_context_not_finished(const llama_context& ctx)
{
    return ctx.state->remaining_tokens > 0;
}

const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
    return llama_tokenize(ctx.vocab, text, true);
}

const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
    return ctx.state->last_n_tokens;
}

llama_context* llama_init_from_params(const gpt_params& params) {
    llama_model model{};
    gpt_vocab vocab{};

    // Compute time taken to load model
    const int64_t t_start = ggml_time_us();
    bool ret = llama_model_load(params.model, model, vocab, 1024);
    const int64_t t_end = ggml_time_us();
    if (!ret) {
        return nullptr;
    }

    llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params);
    ctx->t_load_us = t_end - t_start;
    return ctx;
}

void llama_free_context(llama_context* ctx) {
    delete ctx;
}

const char * llama_print_system_info(void) {
    static std::string s;

    s  = "";
    s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
    s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
    s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
    s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
    s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
    s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";

    return s.c_str();
}

void llama_print_context_info(const llama_context& ctx)
{
    const gpt_params& params = ctx.params;
    const std::vector<gpt_vocab::id>& embd_inp = ctx.state->embd_inp;

    {
        fprintf(stderr, "\n");
        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
    }

    fprintf(stderr, "\n");
    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
    for (int i = 0; i < (int) embd_inp.size(); i++) {
        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], ctx.vocab.id_to_token.at(embd_inp[i]).c_str());
    }
    fprintf(stderr, "\n");
}

void llama_print_end_stats(const llama_context& ctx)
{
    const llama_state& state = *ctx.state;

    fprintf(stderr, "\n\n");
    fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token);
    fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx.t_load_us/1000.0f);
    fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, state.t_sample_us/1000.0f);
    fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, state.t_predict_us/1000.0f, state.t_predict_us/1000.0f/state.n_past);
}

// evaluate the transformer
//
// - model:     the model
// - n_threads: number of threads to use
// - n_past:    the context size so far
// - embd_inp:  the embeddings of the tokens in the context
// - embd_w:    the predicted logits for the next token
//
// The GPT-J model requires about 16MB of memory per input token.
//
bool llama_eval(
        const llama_model & model,
        const int n_threads,
        const int n_past,
        const std::vector<gpt_vocab::id> & embd_inp,
        std::vector<float> & embd_w,
        size_t & mem_per_token) {
    const int N = embd_inp.size();

    const auto & hparams = model.hparams;

    const int n_embd  = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_ctx   = hparams.n_ctx;
    const int n_head  = hparams.n_head;
    const int n_vocab = hparams.n_vocab;
    const int n_rot   = hparams.n_embd/hparams.n_head;

    const int d_key = n_embd/n_head;

    // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
    // static size_t buf_size = hparams.n_ctx*1024*1024;
    static size_t buf_size = 512u*1024*1024;
    static void * buf = malloc(buf_size);

    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

        // reallocate
        buf_size = buf_size_new;
        buf = realloc(buf, buf_size);
        if (buf == nullptr) {
            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
            return false;
        }
    }

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    ggml_cgraph gf = {};
    gf.n_threads = n_threads;

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));

    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);

    for (int il = 0; il < n_layer; ++il) {
        struct ggml_tensor * inpSA = inpL;

        struct ggml_tensor * cur;

        // norm
        {
            cur = ggml_rms_norm(ctx0, inpL);

            // cur = attention_norm*cur
            cur = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
                    cur);
        }

        // self-attention
        {
            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);

            // store key and value to memory
            if (N >= 1) {
                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));

                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
            }

            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_rope(ctx0,
                            ggml_cpy(ctx0,
                                Qcur,
                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
                            n_past, n_rot, 0),
                        0, 2, 1, 3);

            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_rope(ctx0,
                            ggml_reshape_3d(ctx0,
                                ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
                                n_embd/n_head, n_head, n_past + N),
                            n_past, n_rot, 1),
                        0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            struct ggml_tensor * KQ_scaled =
                ggml_scale(ctx0,
                        KQ,
                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));

            // KQ_masked = mask_past(KQ_scaled)
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

            // KQ = soft_max(KQ_masked)
            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
            struct ggml_tensor * V_trans =
                ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
                            n_embd/n_head, n_head, n_past + N),
                        1, 2, 0, 3);

            // KQV = transpose(V) * KQ_soft_max
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            // KQV_merged = KQV.permute(0, 2, 1, 3)
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_embd, N)
            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

            // projection (no bias)
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].wo,
                    cur);
        }

        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);

        // feed-forward network
        {
            // norm
            {
                cur = ggml_rms_norm(ctx0, inpFF);

                // cur = ffn_norm*cur
                cur = ggml_mul(ctx0,
                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
                        cur);
            }

            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                    model.layers[il].w3,
                    cur);

            cur = ggml_mul_mat(ctx0,
                    model.layers[il].w1,
                    cur);

            // SILU activation
            cur = ggml_silu(ctx0, cur);

            cur = ggml_mul(ctx0, cur, tmp);

            cur = ggml_mul_mat(ctx0,
                    model.layers[il].w2,
                    cur);
        }

        cur = ggml_add(ctx0, cur, inpFF);

        // input for next layer
        inpL = cur;
    }

    // norm
    {
        inpL = ggml_rms_norm(ctx0, inpL);

        // inpL = norm*inpL
        inpL = ggml_mul(ctx0,
                ggml_repeat(ctx0, model.norm, inpL),
                inpL);
    }

    // lm_head
    {
        inpL = ggml_mul_mat(ctx0, model.output, inpL);
    }

    // logits -> probs
    //inpL = ggml_soft_max(ctx0, inpL);

    // run the computation
    ggml_build_forward_expand(&gf, inpL);
    ggml_graph_compute (ctx0, &gf);

    //if (n_past%100 == 0) {
    //    ggml_graph_print (&gf);
    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
    //}

    //embd_w.resize(n_vocab*N);
    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

    // return result for just the last token
    embd_w.resize(n_vocab);
    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }
    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));

    ggml_free(ctx0);

    return true;
}
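
For readers following the graph construction above, the per-layer computation it encodes can be summarized as follows (a reference formula, not text from the commit), writing $d = n\_embd / n\_head$ for the per-head dimension and folding the attention_norm / ffn_norm weights into RMSNorm:

    h = x + W_o \,\mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}} + M\right) V,
    \qquad Q = \mathrm{RoPE}(W_q \hat{x}),\; K = \mathrm{RoPE}(W_k \hat{x}),\; V = W_v \hat{x},\; \hat{x} = \mathrm{RMSNorm}(x)

    y = h + W_2\!\left(\mathrm{SiLU}(W_1 \hat{h}) \odot W_3 \hat{h}\right),
    \qquad \hat{h} = \mathrm{RMSNorm}(h)

where M is the causal mask applied by ggml_diag_mask_inf and RoPE is the rotary position embedding applied by ggml_rope.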

bool llama_init_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) {
    llama_state& state = *ctx.state;
    llama_model& model = ctx.model;
    const gpt_params& params = ctx.params;

    if (clear_existing) {
        state.embd.clear();
        state.input_consumed = 0;
        state.embd_inp.clear();
        state.last_n_tokens.clear();
        state.remaining_tokens = 0;
        state.n_past = 0;
    }

    std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
    state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());

    int n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) state.embd_inp.size());
    state.remaining_tokens = n_predict;

    // determine the required inference memory per token:
    state.mem_per_token = 0;
    if (!llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, state.logits, state.mem_per_token)) {
        fprintf(stderr, "Failed to predict with initial prompt\n");
        return false;
    }

    int last_n_size = params.repeat_last_n;
    state.last_n_tokens = std::vector<gpt_vocab::id>(last_n_size);
    std::fill(state.last_n_tokens.begin(), state.last_n_tokens.end(), 0);

    state.is_initialized = true;
    return true;
}

/// @brief Ingests a batch of input tokens into the context
/// @param ctx
void llama_injest_input_batch(llama_context& ctx)
{
    llama_state& state = *ctx.state;
    const gpt_params& params = ctx.params;

    // Copy at most n_batch elements from embd_inp to embd
    size_t num_copied = std::min((size_t) params.n_batch, state.embd_inp.size() - state.input_consumed);
    std::copy(state.embd_inp.begin() + state.input_consumed,
              state.embd_inp.begin() + state.input_consumed + num_copied,
              std::back_inserter(state.embd));
    state.input_consumed += num_copied;

    // Copy the last `repeat_last_n` elements copied into embd to last_n_tokens
    size_t num_copied_last_n = std::min(num_copied, (size_t) params.repeat_last_n);
    state.last_n_tokens.erase(state.last_n_tokens.begin(), state.last_n_tokens.begin() + num_copied_last_n);
    state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
}
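
As a concrete walk-through (hypothetical numbers, not taken from the commit): with params.n_batch = 2 and params.repeat_last_n = 4, if embd_inp holds five tokens t0 … t4 with input_consumed = 2 and last_n_tokens = [0, 0, t0, t1], one call appends t2 and t3 to embd, advances input_consumed to 4, and slides the repeat-penalty window to last_n_tokens = [t0, t1, t2, t3] by erasing two tokens at the front and appending the two just copied.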

/// @brief Run the prediction step on ctx.embd and store result in ctx.state.logits
/// @param ctx
/// @return
bool llama_predict(llama_context& ctx) {
    const int64_t t_start_us = ggml_time_us();

    llama_state& state = *ctx.state;
    llama_model& model = ctx.model;
    const gpt_params& params = ctx.params;

    if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
        fprintf(stderr, "Failed to predict\n");
        return false;
    }

    state.t_predict_us += ggml_time_us() - t_start_us;
    return true;
}

/// @brief Sample a token from the logits
/// @param ctx
/// @return token id
gpt_vocab::id llama_sample_token(llama_context& ctx)
{
    llama_state& state = *ctx.state;
    llama_model& model = ctx.model;
    const gpt_params& params = ctx.params;

    const float top_k = params.top_k;
    const float top_p = params.top_p;
    const float temp  = params.temp;
    const float repeat_penalty = params.repeat_penalty;

    const int n_vocab = model.hparams.n_vocab;

    gpt_vocab::id id = 0;

    {
        const int64_t t_start_sample_us = ggml_time_us();

        id = llama_sample_top_p_top_k(ctx.vocab, state.logits.data() + (state.logits.size() - n_vocab),
                state.last_n_tokens, repeat_penalty, top_k, top_p, temp, state.rng);

        state.last_n_tokens.erase(state.last_n_tokens.begin());
        state.last_n_tokens.push_back(id);

        state.t_sample_us += ggml_time_us() - t_start_sample_us;
    }

    return id;
}

/// @brief Ingest all input (in multiple batches) into the model and run predict() on each batch
/// @param ctx
bool llama_injest_input(llama_context& ctx, const std::string& text, bool clear_existing)
{
    llama_state& state = *ctx.state;

    // Initialize the context, tokenize the text and clear existing state if necessary
    if (!state.is_initialized && !llama_init_context_with_prompt(ctx, text, clear_existing)) {
        return false;
    }

    // Ingest the tokens into the model one batch at a time
    while (state.has_more_input()) {
        llama_injest_input_batch(ctx);
        if (state.embd.size() > 0) {
            if (!llama_predict(ctx)) {
                return false;
            }
        }
        state.n_past += state.embd.size();
        state.embd.clear();
    }

    return true;
}

bool llama_inference(llama_context& ctx, gpt_vocab::id& id) {
    llama_state& state = *ctx.state;

    // The state must be initialized (prompt ingested) before running inference
    if (!state.is_initialized) {
        fprintf(stderr, "State must be initialized before running inference\n");
        return false;
    }

    // No more tokens to generate
    if (state.remaining_tokens <= 0) {
        return false;
    }

    // Run prediction if there are pending tokens in embd
    if (state.embd.size() > 0) {
        if (!llama_predict(ctx)) {
            return false;
        }
    }

    // sample a token
    id = llama_sample_token(ctx);

    // add it to the context
    state.embd.push_back(id);

    state.n_past += 1;

    // decrement remaining sampling budget
    --state.remaining_tokens;

    // end of text token
    if (state.embd.back() == 2) {
        state.remaining_tokens = 0;
    }

    return true;
}
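
Taken together, the functions above form the new external API declared in llama.h below. A hedged sketch of how a caller is expected to drive it (not code from this commit; the model path and prompt are placeholders, and gpt_params is assumed to come from the existing utils.h with model, prompt and sampling fields populated):

    #include "llama.h"
    #include <cstdio>

    int main() {
        gpt_params params;                                 // assumed defaults from utils.h
        params.model  = "models/7B/ggml-model-f16.bin";    // hypothetical path
        params.prompt = "Building a website can be done in 10 simple steps:";

        llama_context * ctx = llama_init_from_params(params);
        if (ctx == nullptr) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        // Tokenize the prompt, measure mem_per_token and feed the tokens to the model
        if (!llama_injest_input(*ctx, params.prompt)) {
            llama_free_context(ctx);
            return 1;
        }

        llama_print_context_info(*ctx);

        // Sample one token at a time until the budget is spent or an end-of-text token appears
        gpt_vocab & vocab = llama_context_get_vocab(*ctx);
        gpt_vocab::id id = 0;
        while (llama_context_not_finished(*ctx) && llama_inference(*ctx, id)) {
            printf("%s", vocab.id_to_token.at(id).c_str());
            fflush(stdout);
        }

        llama_print_end_stats(*ctx);
        llama_free_context(ctx);
        return 0;
    }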

llama.h (77 changes)

@@ -3,8 +3,27 @@
#include <string>
#include <vector>
#include <map>
#include <memory>

#include "utils.h"
#include "ggml.h"

#ifdef LLAMA_SHARED
#  ifdef _WIN32
#    ifdef LLAMA_BUILD
#      define LLAMA_API __declspec(dllexport)
#    else
#      define LLAMA_API __declspec(dllimport)
#    endif
#  else
#    define LLAMA_API __attribute__ ((visibility ("default")))
#  endif
#else
#  define LLAMA_API
#endif
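
As an illustration only (not a line from this diff), a declaration exported from a shared-library build would carry the macro like so:

    LLAMA_API llama_context* llama_init_from_params(const gpt_params& params);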

// default hparams (LLaMA 7B)
struct llama_hparams {

@@ -18,50 +37,28 @@ struct llama_hparams {
    int32_t f16 = 1;
};

struct llama_layer {
    // normalization
    struct ggml_tensor * attention_norm;
struct llama_context;

    // attention
    struct ggml_tensor * wq;
    struct ggml_tensor * wk;
    struct ggml_tensor * wv;
    struct ggml_tensor * wo;
void llama_free_context(llama_context* ctx);

    // normalization
    struct ggml_tensor * ffn_norm;
const std::vector<gpt_vocab::id>& llama_context_get_embd(const llama_context& ctx);
gpt_vocab& llama_context_get_vocab(llama_context& ctx);
bool llama_context_not_finished(const llama_context& ctx);
const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text);

    // ff
    struct ggml_tensor * w1;
    struct ggml_tensor * w2;
    struct ggml_tensor * w3;
};
const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx);
bool llama_init_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing = true);

struct llama_model {
    llama_hparams hparams;
// Various functions for loading a ggml LLaMA model.
llama_context* llama_init_from_params(const gpt_params& params);

    struct ggml_tensor * tok_embeddings;
// Run inference on a LLaMA model using llama_context.
std::vector<float> llama_eval(llama_context& ctx, const gpt_params& params, std::string& text);

    struct ggml_tensor * norm;
    struct ggml_tensor * output;

    std::vector<llama_layer> layers;

    // key + value memory
    struct ggml_tensor * memory_k;
    struct ggml_tensor * memory_v;

    //
    struct ggml_context * ctx;
    std::map<std::string, struct ggml_tensor *> tensors;
};

bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx);
bool llama_eval(
        const llama_model & model,
        const int n_threads,
        const int n_past,
        const std::vector<gpt_vocab::id> & embd_inp,
        std::vector<float> & embd_w,
        size_t & mem_per_token);
bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype);

bool llama_injest_input(llama_context& ctx, const std::string& text, bool clear_existing = true);

bool llama_inference(llama_context& ctx, gpt_vocab::id& model_output);
void llama_print_context_info(const llama_context& ctx);
void llama_print_end_stats(const llama_context& ctx);