Refactor llama.cpp and llama.h
This commit is contained in:
parent
05224ed472
commit
3839a08cee
2 changed files with 267 additions and 175 deletions
410
llama.cpp
410
llama.cpp
|
@ -72,33 +72,36 @@ struct llama_model {
|
|||
};
|
||||
struct llama_state
|
||||
{
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_predict_us = 0;
|
||||
// Timers
|
||||
struct timing {
|
||||
int64_t t_load_us = 0;
|
||||
|
||||
std::vector<float> logits;
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_predict_us = 0;
|
||||
} timing;
|
||||
|
||||
mutable std::mt19937 rng;
|
||||
// Random number generator
|
||||
std::mt19937 rng{};
|
||||
|
||||
// Tokens
|
||||
std::vector<gpt_vocab::id> embd{};
|
||||
std::vector<gpt_vocab::id> embd_inp{};
|
||||
std::vector<gpt_vocab::id> last_n_tokens{};
|
||||
|
||||
// Logits from inference
|
||||
std::vector<float> logits{};
|
||||
|
||||
// Counters
|
||||
int input_consumed = 0;
|
||||
std::vector<gpt_vocab::id> embd_inp;
|
||||
std::vector<gpt_vocab::id> last_n_tokens;
|
||||
int remaining_tokens = 0;
|
||||
int n_past = 0;
|
||||
size_t mem_per_token = 0;
|
||||
bool is_initialized = false;
|
||||
llama_state() {}
|
||||
|
||||
bool has_more_input() const {
|
||||
return input_consumed < embd_inp.size();
|
||||
}
|
||||
// Flag set after initialization
|
||||
bool is_initialized = false;
|
||||
};
|
||||
struct llama_context
|
||||
{
|
||||
int64_t t_load_us = 0;
|
||||
int64_t t_start_us = 0;
|
||||
|
||||
ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
|
||||
|
||||
llama_model model{};
|
||||
|
@ -111,8 +114,6 @@ struct llama_context
|
|||
llama_context() = default;
|
||||
// constructor
|
||||
llama_context(llama_model&& model, gpt_vocab&& vocab, const gpt_params& params):
|
||||
t_load_us(0),
|
||||
t_start_us(0),
|
||||
wtype(ggml_type::GGML_TYPE_F16),
|
||||
model(std::move(model)),
|
||||
vocab(std::move(vocab)),
|
||||
|
@ -125,6 +126,7 @@ struct llama_context
|
|||
}
|
||||
};
|
||||
|
||||
/* Original code by @ggerganov */
|
||||
// load the model's weights from a file
|
||||
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
|
||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||
|
@ -806,93 +808,6 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/* External API */
|
||||
|
||||
const std::vector<gpt_vocab::id>& llama_context_get_embedding(const llama_context& ctx) {
|
||||
return ctx.state->embd;
|
||||
}
|
||||
gpt_vocab& llama_context_get_vocab(llama_context& ctx) {
|
||||
return ctx.vocab;
|
||||
}
|
||||
bool llama_context_is_finished(const llama_context& ctx)
|
||||
{
|
||||
return ctx.state->remaining_tokens <= 0;
|
||||
}
|
||||
const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
|
||||
return llama_tokenize(ctx.vocab, text, true);
|
||||
}
|
||||
const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
|
||||
return ctx.state->last_n_tokens;
|
||||
}
|
||||
llama_context* llama_init_from_params(const gpt_params& params) {
|
||||
llama_model model{};
|
||||
gpt_vocab vocab{};
|
||||
|
||||
// Compute time taken to load model
|
||||
const int64_t t_start = ggml_time_us();
|
||||
bool ret = llama_model_load(params.model, model, vocab, 1024);
|
||||
const int64_t t_end = ggml_time_us();
|
||||
if(!ret)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params);
|
||||
ctx->t_load_us = t_end - t_start;
|
||||
return ctx;
|
||||
}
|
||||
void llama_free_context(llama_context* ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
const char * llama_print_system_info(void) {
|
||||
static std::string s;
|
||||
|
||||
s = "";
|
||||
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
|
||||
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
|
||||
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
|
||||
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
|
||||
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
||||
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
|
||||
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
|
||||
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
||||
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
||||
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
||||
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
||||
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
||||
|
||||
return s.c_str();
|
||||
}
|
||||
|
||||
void llama_print_context_info(const llama_context& ctx)
|
||||
{
|
||||
const gpt_params& params = ctx.params;
|
||||
const std::vector<gpt_vocab::id>& embd_inp = ctx.state->embd_inp;
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
||||
for (int i = 0; i < (int) embd_inp.size(); i++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], ctx.vocab.id_to_token.at(embd_inp[i]).c_str());
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
void llama_print_end_stats(const llama_context& ctx)
|
||||
{
|
||||
const llama_state& state = *ctx.state;
|
||||
fprintf(stderr, "\n\n");
|
||||
fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token);
|
||||
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx.t_load_us/1000.0f);
|
||||
fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, state.t_sample_us/1000.0f);
|
||||
fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, state.t_predict_us/1000.0f, state.t_predict_us/1000.0f/state.n_past);
|
||||
}
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - model: the model
|
||||
|
@ -1129,25 +1044,56 @@ bool llama_eval(
|
|||
return true;
|
||||
}
|
||||
|
||||
bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) {
|
||||
const char * llama_print_system_info(void) {
|
||||
static std::string s;
|
||||
|
||||
s = "";
|
||||
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
|
||||
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
|
||||
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
|
||||
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
|
||||
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
||||
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
|
||||
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
|
||||
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
||||
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
||||
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
||||
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
||||
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
||||
|
||||
return s.c_str();
|
||||
}
|
||||
|
||||
/* External API */
|
||||
/// @brief Initialize the context from a set of parameters
|
||||
/// @param params
|
||||
llama_context* llama_init_from_params(const gpt_params& params) {
|
||||
llama_model model{};
|
||||
gpt_vocab vocab{};
|
||||
|
||||
// Compute time taken to load model
|
||||
const int64_t t_start = ggml_time_us();
|
||||
bool ret = llama_model_load(params.model, model, vocab, 1024);
|
||||
const int64_t t_end = ggml_time_us();
|
||||
if(!ret)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params);
|
||||
ctx->state->timing.t_load_us = t_end - t_start;
|
||||
ctx->state->rng = std::mt19937(params.seed);
|
||||
return ctx;
|
||||
}
|
||||
/// @brief Prepare the context for inference
|
||||
/// @param ctx
|
||||
bool llama_prepare_context(llama_context& ctx)
|
||||
{
|
||||
llama_state& state = *ctx.state;
|
||||
llama_model& model = ctx.model;
|
||||
const gpt_params& params = ctx.params;
|
||||
|
||||
if (clear_existing) {
|
||||
state.embd.clear();
|
||||
state.input_consumed = 0;
|
||||
state.embd_inp.clear();
|
||||
state.last_n_tokens.clear();
|
||||
state.remaining_tokens = 0;
|
||||
state.n_past = 0;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
|
||||
state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||
gpt_params& params = ctx.params;
|
||||
|
||||
int n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) state.embd_inp.size());
|
||||
state.remaining_tokens = n_predict;
|
||||
params.n_predict = n_predict;
|
||||
|
||||
// determine the required inference memory per token:
|
||||
state.mem_per_token = 0;
|
||||
|
@ -1160,10 +1106,76 @@ bool llama_update_context_with_prompt(llama_context& ctx, const std::string& tex
|
|||
int last_n_size = params.repeat_last_n;
|
||||
state.last_n_tokens = std::vector<gpt_vocab::id>(last_n_size);
|
||||
std::fill(state.last_n_tokens.begin(), state.last_n_tokens.end(), 0);
|
||||
|
||||
state.is_initialized = true;
|
||||
state.remaining_tokens = params.n_predict;
|
||||
state.input_consumed = 0;
|
||||
return true;
|
||||
}
|
||||
/// @brief Free the context
|
||||
void llama_free_context(llama_context* ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
/* Getters and setters */
|
||||
/// @brief Get the embedding vector for the last token
|
||||
const std::vector<gpt_vocab::id>& llama_context_get_embedding(const llama_context& ctx) {
|
||||
return ctx.state->embd;
|
||||
}
|
||||
/// @brief Get the vector for the last token
|
||||
/// @param ctx
|
||||
gpt_vocab& llama_context_get_vocab(llama_context& ctx) {
|
||||
return ctx.vocab;
|
||||
}
|
||||
/// @brief Is the context finished?
|
||||
/// @param ctx
|
||||
bool llama_context_is_finished(const llama_context& ctx)
|
||||
{
|
||||
return ctx.state->remaining_tokens <= 0;
|
||||
}
|
||||
/// @brief Is the context finished?
|
||||
/// @param ctx
|
||||
void llama_reset_remaining_tokens(const llama_context& ctx)
|
||||
{
|
||||
ctx.state->remaining_tokens = ctx.params.n_predict;
|
||||
}
|
||||
|
||||
/// @brief Tokenize a text into a vector of ids
|
||||
/// @param ctx
|
||||
/// @param text
|
||||
const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
|
||||
// Make sure that the "beginning of string" token is not prefixed to the text
|
||||
return llama_tokenize(ctx.vocab, text, false);
|
||||
}
|
||||
const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
|
||||
return ctx.state->last_n_tokens;
|
||||
}
|
||||
|
||||
/// @brief Adds the "beginning of string" token to the model input
|
||||
/// @param ctx
|
||||
void llama_add_bos(llama_context& ctx){
|
||||
// Add the "bos" token into the model input
|
||||
llama_state& state = *ctx.state;
|
||||
llama_model& model = ctx.model;
|
||||
const gpt_params& params = ctx.params;
|
||||
|
||||
const gpt_vocab::id bos_token = 1;
|
||||
state.embd_inp.push_back(bos_token);
|
||||
}
|
||||
|
||||
/// @brief Updates the context and appends new input text
|
||||
/// @param ctx
|
||||
/// @param text
|
||||
void llama_update_input(llama_context& ctx, const std::string& text)
|
||||
{
|
||||
llama_state& state = *ctx.state;
|
||||
llama_model& model = ctx.model;
|
||||
const gpt_params& params = ctx.params;
|
||||
|
||||
std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
|
||||
|
||||
state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||
state.remaining_tokens -= line_inp.size();
|
||||
}
|
||||
|
||||
/// @brief Ingests a batch of input tokens into the context
|
||||
/// @param ctx
|
||||
|
@ -1173,7 +1185,7 @@ void llama_ingest_input_batch(llama_context& ctx)
|
|||
const gpt_params& params = ctx.params;
|
||||
|
||||
// Copy at most n_batch elements from embd_inp to embd
|
||||
size_t num_copied = std::min((size_t) params.n_batch, state.embd_inp.size() - state.input_consumed);
|
||||
size_t num_copied = std::min((size_t) params.n_batch+1, state.embd_inp.size() - state.input_consumed);
|
||||
std::copy(state.embd_inp.begin() + state.input_consumed,
|
||||
state.embd_inp.begin() + state.input_consumed + num_copied,
|
||||
std::back_inserter(state.embd));
|
||||
|
@ -1185,21 +1197,60 @@ void llama_ingest_input_batch(llama_context& ctx)
|
|||
state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
|
||||
}
|
||||
|
||||
/// @brief Run the prediction step on ctx.embd and store result in ctx.state.logits
|
||||
/// @brief Returns true if there is unconsumed input in the context
|
||||
/// @param ctx
|
||||
/// @return
|
||||
bool llama_predict(llama_context& ctx){
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
bool llama_has_unconsumed_input(llama_context& ctx)
|
||||
{
|
||||
llama_state& state = *ctx.state;
|
||||
return state.input_consumed < state.embd_inp.size();
|
||||
}
|
||||
|
||||
/// @brief Ingest all input (in multiple batches) into model and run call predict()
|
||||
/// @param ctx
|
||||
bool llama_ingest_all_pending_input(llama_context& ctx, bool print_tokens)
|
||||
{
|
||||
llama_state& state = *ctx.state;
|
||||
const std::vector<gpt_vocab::id>& embd = state.embd;
|
||||
gpt_vocab& vocab = ctx.vocab;
|
||||
|
||||
if(!state.is_initialized)
|
||||
{
|
||||
fprintf(stderr, "Context must be initialized before ingesting input");
|
||||
return false;
|
||||
}
|
||||
|
||||
// ingest the tokens into the model one batch at a time
|
||||
while (llama_has_unconsumed_input(ctx))
|
||||
{
|
||||
llama_ingest_input_batch(ctx);
|
||||
if (print_tokens) {
|
||||
std::string s = llama_tokens_to_string(vocab, embd);
|
||||
printf("%s", s.c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
llama_eval_model(ctx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/// @brief Evaluate the model with the current input batch
|
||||
/// @param ctx
|
||||
bool llama_eval_model(llama_context& ctx)
|
||||
{
|
||||
llama_state& state = *ctx.state;
|
||||
llama_model& model = ctx.model;
|
||||
const gpt_params& params = ctx.params;
|
||||
|
||||
if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
|
||||
fprintf(stderr, "Failed to predict\n");
|
||||
return false;
|
||||
}
|
||||
if (state.embd.size() > 0) {
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
state.t_predict_us += ggml_time_us() - t_start_us;
|
||||
if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
|
||||
fprintf(stderr, "Failed to predict\n");
|
||||
return false;
|
||||
}
|
||||
state.timing.t_predict_us += ggml_time_us() - t_start_us;
|
||||
}
|
||||
state.n_past += state.embd.size();
|
||||
state.embd.clear();
|
||||
return true;
|
||||
}
|
||||
/// @brief Sample a token from the logits
|
||||
|
@ -1229,37 +1280,13 @@ gpt_vocab::id llama_sample_token(llama_context& ctx)
|
|||
state.last_n_tokens.erase(state.last_n_tokens.begin());
|
||||
state.last_n_tokens.push_back(id);
|
||||
|
||||
state.t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
state.timing.t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
return id;
|
||||
}
|
||||
/// @brief Ingest all input (in multiple batches) into model and run call predict()
|
||||
/// @param ctx
|
||||
bool llama_ingest_input(llama_context& ctx, const std::string& text, bool clear_existing)
|
||||
{
|
||||
llama_state& state = *ctx.state;
|
||||
|
||||
// Initialize context, tokenize text and clear existing state if necessary
|
||||
if(!state.is_initialized && !llama_update_context_with_prompt(ctx, text, clear_existing))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// ingest the tokens into the model one batch at a time
|
||||
while (state.has_more_input())
|
||||
{
|
||||
llama_ingest_input_batch(ctx);
|
||||
if (state.embd.size() >= 0) {
|
||||
if(!llama_predict(ctx))
|
||||
{
|
||||
return false;
|
||||
};
|
||||
}
|
||||
state.n_past += state.embd.size();
|
||||
state.embd.clear();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/// @brief Run inference for one token and return the token id
|
||||
/// @param ctx
|
||||
/// @param id
|
||||
bool llama_infer(llama_context& ctx, gpt_vocab::id& id) {
|
||||
llama_state& state = *ctx.state;
|
||||
|
||||
|
@ -1275,25 +1302,68 @@ bool llama_infer(llama_context& ctx, gpt_vocab::id& id) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Do prediction if we have enough tokens
|
||||
if (state.embd.size() > 0) {
|
||||
if(!llama_predict(ctx))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// sample a token
|
||||
// Already predicted, so we just need to sample
|
||||
// sample a token
|
||||
id = llama_sample_token(ctx);
|
||||
|
||||
// add it to the context
|
||||
state.embd.push_back(id);
|
||||
|
||||
state.n_past += 1;
|
||||
// decrement remaining sampling budget
|
||||
--state.remaining_tokens;
|
||||
|
||||
// end of text token
|
||||
if (state.embd.back() == 2) {
|
||||
state.remaining_tokens = 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/// @brief Run inference for one token and return the token as a string
|
||||
/// @param ctx
|
||||
/// @param output
|
||||
/// @param is_end_of_text
|
||||
bool llama_infer(llama_context& ctx, std::string& output, bool& is_end_of_text) {
|
||||
// Call overloaded llama_infer and convert to string before returning
|
||||
gpt_vocab::id id_int;
|
||||
is_end_of_text = false;
|
||||
if(!llama_infer(ctx, id_int)){
|
||||
return false;
|
||||
}
|
||||
|
||||
// Pass through the "end of text" token to the user
|
||||
is_end_of_text = (id_int == EOS_TOKEN_ID);
|
||||
|
||||
// Make sure to pass in the newly generated token to the model as well
|
||||
llama_eval_model(ctx);
|
||||
output = ctx.vocab.id_to_token.at(id_int);
|
||||
return true;
|
||||
}
|
||||
bool llama_is_anti_prompt_present(llama_context& ctx, const std::vector<gpt_vocab::id>& antiprompt_inp)
|
||||
{
|
||||
llama_state& state = *ctx.state;
|
||||
return std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), state.last_n_tokens.rbegin());
|
||||
}
|
||||
|
||||
void llama_print_startup_stats(const llama_context& ctx)
|
||||
{
|
||||
const gpt_params& params = ctx.params;
|
||||
const std::vector<gpt_vocab::id>& embd_inp = ctx.state->embd_inp;
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
||||
for (int i = 0; i < (int) embd_inp.size(); i++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], ctx.vocab.id_to_token.at(embd_inp[i]).c_str());
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
void llama_print_end_stats(const llama_context& ctx)
|
||||
{
|
||||
const llama_state& state = *ctx.state;
|
||||
fprintf(stderr, "\n\n");
|
||||
fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token);
|
||||
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, state.timing.t_load_us/1000.0f);
|
||||
fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, state.timing.t_sample_us/1000.0f);
|
||||
fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, state.timing.t_predict_us/1000.0f, state.timing.t_predict_us/1000.0f/state.n_past);
|
||||
}
|
||||
|
|
32
llama.h
32
llama.h
|
@ -22,7 +22,7 @@
|
|||
# define LLAMA_API
|
||||
#endif
|
||||
|
||||
|
||||
static const int EOS_TOKEN_ID = 2;
|
||||
|
||||
|
||||
// default hparams (LLaMA 7B)
|
||||
|
@ -41,13 +41,35 @@ struct llama_context;
|
|||
|
||||
// Startup
|
||||
llama_context* llama_init_from_params(const gpt_params& params);
|
||||
bool llama_prepare_context(llama_context& ctx);
|
||||
|
||||
// Input processing and inference
|
||||
bool llama_ingest_input(llama_context& ctx, const std::string& text, bool clear_existing = true);
|
||||
bool llama_context_is_finished(const llama_context& ctx);
|
||||
bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing = true);
|
||||
// Tokenize text (never adds BOS)
|
||||
const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text);
|
||||
// Queues up a BOS token to the model input
|
||||
void llama_add_bos(llama_context& ctx);
|
||||
// Queues up input text to the model input
|
||||
void llama_update_input(llama_context& ctx, const std::string& text);
|
||||
// Ingests input previously added using llama_update_input()
|
||||
void llama_ingest_input_batch(llama_context& ctx);
|
||||
// Ingests all input previously added using llama_update_input() in multiple batches
|
||||
// Batch size is determined by gpt_params::n_predict
|
||||
bool llama_ingest_all_pending_input(llama_context& ctx, bool print_tokens = false);
|
||||
// Checks if the model has unconsumed input to be ingested using llama_ingest_input_batch()
|
||||
bool llama_has_unconsumed_input(llama_context& ctx);
|
||||
// Checks if the model has an anti-prompt present its most recent output
|
||||
bool llama_is_anti_prompt_present(llama_context& ctx, const std::vector<gpt_vocab::id>& antiprompt_inp);
|
||||
|
||||
// Evaluate the model on a batch of input. Must call llama_ingest_input_batch() first.
|
||||
bool llama_eval_model(llama_context& ctx);
|
||||
// Checks if the model has finished generating output (i.e. has generated an EOS token or remaining_tokens == 0)
|
||||
bool llama_context_is_finished(const llama_context& ctx);
|
||||
// Resets the remaining_tokens counter to the value specified in the gpt_params
|
||||
void llama_reset_remaining_tokens(const llama_context& ctx);
|
||||
|
||||
// Overloaded functions to run inference and return either the model output or the decoded text
|
||||
bool llama_infer(llama_context& ctx, gpt_vocab::id& model_output);
|
||||
bool llama_infer(llama_context& ctx, std::string& output, bool& is_end_of_text);
|
||||
|
||||
// Teardown
|
||||
void llama_free_context(llama_context* ctx);
|
||||
|
@ -61,5 +83,5 @@ const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_co
|
|||
bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype);
|
||||
|
||||
// Stats
|
||||
void llama_print_context_info(const llama_context& ctx);
|
||||
void llama_print_startup_stats(const llama_context& ctx);
|
||||
void llama_print_end_stats(const llama_context& ctx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue