llama : add llama_vocab, functions -> methods, naming (#11110)

* llama : functions -> methods (#11110)

* llama : add struct llama_vocab to the API (#11156)

ggml-ci

* hparams : move vocab params to llama_vocab (#11159)

ggml-ci

* vocab : more pimpl (#11165)

ggml-ci

* vocab : minor tokenization optimizations (#11160)

ggml-ci

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* lora : update API names (#11167)

ggml-ci

* llama : update API names to use correct prefix (#11174)

* llama : update API names to use correct prefix

ggml-ci

* cont

ggml-ci

* cont

ggml-ci

* minor [no ci]

* vocab : llama_vocab_add_[be]os -> llama_vocab_get_add_[be]os (#11174)

ggml-ci

* vocab : llama_vocab_n_vocab -> llama_vocab_n_tokens (#11174)

ggml-ci

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2025-01-12 11:32:42 +02:00 committed by GitHub
parent c05e8c9934
commit afa8a9ec9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
68 changed files with 5855 additions and 5400 deletions

View file

@ -50,7 +50,7 @@ int main(int argc, char ** argv) {
// ensure enough sequences are available
ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end());
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

View file

@ -23,12 +23,12 @@ defer {
}
let model_params = llama_model_default_params()
guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else {
guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else {
print("Failed to load model")
exit(1)
}
defer {
llama_free_model(model)
llama_model_free(model)
}
var tokens = tokenize(text: prompt, add_bos: true)
@ -141,7 +141,7 @@ while n_cur <= n_len {
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
// is it an end of stream? -> mark the stream as finished
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
i_batch[i] = -1
// print("")
if n_parallel > 1 {

View file

@ -48,10 +48,12 @@ int main(int argc, char ** argv) {
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = common_tokenize(model, params.prompt, true);
tokens_list = common_tokenize(vocab, params.prompt, true);
const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;
@ -62,7 +64,7 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_predict, n_parallel);
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_context * ctx = llama_init_from_model(model, ctx_params);
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
@ -121,7 +123,7 @@ int main(int argc, char ** argv) {
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = llama_token_bos(model);
decoder_start_token_id = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
@ -174,7 +176,7 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
i_batch[i] = -1;
LOG("\n");
if (n_parallel > 1) {

View file

@ -911,7 +911,7 @@ int main(int argc, char ** argv) {
load_vocab(params.fn_vocab_model, &config, &vocab);
struct my_llama_model model;
model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx);
model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx);
model.hparams.n_ctx = params.n_ctx;
model.hparams.n_embd = config.dim; //params.n_embd;
model.hparams.n_ff = config.hidden_dim;

View file

@ -273,7 +273,9 @@ struct tokenized_prompt {
size_t max_seq_len;
tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
tokens_pos = common_tokenize(ctx, pos, add_bos, true);
tokens_neg = common_tokenize(ctx, neg, add_bos, true);
max_seq_len = std::max(tokens_pos.size(), tokens_neg.size());
@ -421,8 +423,8 @@ int main(int argc, char ** argv) {
llama_context * ctx = llama_init.context.get();
// int n_ctx = llama_n_ctx(ctx);
int n_layers = llama_n_layer(model);
int n_embd = llama_n_embd(model);
int n_layers = llama_model_n_layer(model);
int n_embd = llama_model_n_embd(model);
// get model hint param (a.k.a model arch name)
char model_hint[128];

View file

@ -105,7 +105,9 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
@ -148,7 +150,7 @@ int main(int argc, char ** argv) {
// check if the last token is SEP
// it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
for (auto & inp : inputs) {
if (inp.empty() || inp.back() != llama_token_sep(model)) {
if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
}
@ -181,7 +183,7 @@ int main(int argc, char ** argv) {
}
// allocate output
const int n_embd = llama_n_embd(model);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_embd_count * n_embd, 0);
float * emb = embeddings.data();

View file

@ -127,7 +127,10 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
}
static bool run(llama_context * ctx, const common_params & params) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

View file

@ -8,7 +8,6 @@
#include <map>
#include <vector>
#include <string>
#include <thread>
#include <fstream>
static bool g_verbose = false;
@ -130,7 +129,7 @@ struct lora_merge_ctx {
lora_merge_ctx(
std::string & base_fname,
std::vector<common_lora_adapter_info> & lora_files,
std::vector<common_adapter_lora_info> & lora_files,
std::string & outfile,
int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
fout.exceptions(std::ofstream::failbit); // fail fast on write errors

View file

@ -11,6 +11,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
std::vector<std::vector<float>> result;
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
@ -19,16 +20,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const std::string input_string = instruction + sentences[i];
std::vector<llama_token> inputs = common_tokenize(model, input_string, true, false);
std::vector<llama_token> inputs = common_tokenize(vocab, input_string, true, false);
const int32_t n_toks = inputs.size();
// GritLM seems to have EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(model));
// inputs.push_back(llama_vocab_eos(vocab));
// we want to ignore instruction tokens for mean pooling
const int32_t n_inst = common_tokenize(model, instruction, true, false).size();
const int32_t n_inst = common_tokenize(vocab, instruction, true, false).size();
#ifdef GRIT_DEBUG
// debug tokens - should be matching as referenced in the GritLM sample
@ -52,7 +53,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
llama_decode(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_n_embd(model);
uint64_t n_embd = llama_model_n_embd(model);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
@ -97,7 +98,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
std::string result;
const llama_model * model = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_token eos_token = llama_vocab_eos(vocab);
llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, false);
@ -105,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = common_tokenize(model, prompt, false, true);
std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
@ -168,7 +171,7 @@ int main(int argc, char * argv[]) {
llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams);
// create generation context
llama_context * ctx = llama_new_context_with_model(model, cparams);
llama_context * ctx = llama_init_from_model(model, cparams);
auto sparams = llama_sampler_chain_default_params();
@ -197,7 +200,7 @@ int main(int argc, char * argv[]) {
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
const int n_embd = llama_n_embd(model);
const int n_embd = llama_model_n_embd(model);
const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);

View file

@ -7,7 +7,6 @@
#include <cstdio>
#include <cstring>
#include <ctime>
#include <sstream>
#include <thread>
#include <mutex>
#include <vector>
@ -40,7 +39,7 @@ public:
void set_params(common_params params) { m_params = std::move(params); }
bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
void save_imatrix(int ncall = -1) const;
bool load_imatrix(const char * file_name);
bool load_imatrix(const char * fname);
private:
std::unordered_map<std::string, Stats> m_stats;
common_params m_params;
@ -429,10 +428,13 @@ static void process_logits(
}
static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
const int n_ctx = llama_n_ctx(ctx);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
auto tim1 = std::chrono::high_resolution_clock::now();
LOG_INF("%s: tokenizing the input ..\n", __func__);
@ -468,7 +470,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_vocab_n_tokens(vocab);
const int n_batch = params.n_batch;
int count = 0;
@ -508,7 +510,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
@ -627,7 +629,7 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);

View file

@ -139,7 +139,9 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
LOG_DBG("n_ctx: %d\n", n_ctx);
@ -152,28 +154,28 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}
const bool add_bos = llama_add_bos_token(model);
GGML_ASSERT(!llama_add_eos_token(model));
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
std::vector<llama_token> embd_inp;
std::vector<llama_token> embd_end;
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
GGML_ASSERT(llama_token_fim_pre(model) >= 0);
GGML_ASSERT(llama_token_fim_suf(model) >= 0);
GGML_ASSERT(llama_vocab_fim_pre(vocab) >= 0);
GGML_ASSERT(llama_vocab_fim_suf(vocab) >= 0);
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab));
inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab));
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
const llama_token middle_token = llama_token_fim_mid(model);
const llama_token middle_token = llama_vocab_fim_mid(vocab);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}
@ -185,7 +187,7 @@ int main(int argc, char ** argv) {
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(model));
embd_inp.push_back(llama_vocab_bos(vocab));
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
}
@ -420,10 +422,10 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
if ((common_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
if ((common_sampler_last(smpl) == llama_vocab_eot(vocab) || is_interacting) && params.interactive){
if (is_interacting && !params.interactive_first) {
// print an eot token
LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str());
LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str());
}
LOG("\n");
console::set_display(console::user_input);
@ -463,13 +465,13 @@ int main(int argc, char ** argv) {
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab));
inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab));
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
@ -484,7 +486,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of generation tokens in interactive mode
else if (llama_token_is_eog(model, common_sampler_last(smpl))) {
else if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
LOG_DBG("found EOS token\n");
if (params.interactive) {
@ -500,7 +502,7 @@ int main(int argc, char ** argv) {
if (params.input_prefix_bos) {
LOG_DBG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(model));
embd_inp.push_back(llama_vocab_bos(vocab));
}
std::string buffer;
@ -563,7 +565,7 @@ int main(int argc, char ** argv) {
}
// end of generation
if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) {
if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !params.interactive) {
break;
}
@ -575,7 +577,7 @@ int main(int argc, char ** argv) {
}
}
if (!params.interactive && n_remain <= 0) {
LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str());
LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str());
}
LOG("\n");

View file

@ -1401,7 +1401,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
const int32_t n_vocab = llama_n_vocab(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
std::vector<llama_token> tokens(n_batch);
@ -1409,7 +1410,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
tokens[0] = n_processed == 0 && llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
@ -1424,9 +1425,10 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
const int32_t n_vocab = llama_n_vocab(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1));
@ -1537,7 +1539,7 @@ int main(int argc, char ** argv) {
prev_inst = &inst;
}
llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams());
llama_context * ctx = llama_init_from_model(lmodel, inst.to_llama_cparams());
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
llama_model_free(lmodel);

View file

@ -87,7 +87,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi
auto path_to_model = env->GetStringUTFChars(filename, 0);
LOGi("Loading model from %s", path_to_model);
auto model = llama_load_model_from_file(path_to_model, model_params);
auto model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model);
if (!model) {
@ -102,7 +102,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
llama_free_model(reinterpret_cast<llama_model *>(model));
llama_model_free(reinterpret_cast<llama_model *>(model));
}
extern "C"
@ -405,6 +405,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
const auto model = llama_get_model(context);
const auto vocab = llama_model_get_vocab(model);
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
@ -414,7 +415,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
return nullptr;
}

View file

@ -52,8 +52,8 @@ actor LlamaContext {
deinit {
llama_sampler_free(sampling)
llama_batch_free(batch)
llama_model_free(model)
llama_free(context)
llama_free_model(model)
llama_backend_free()
}
@ -65,7 +65,7 @@ actor LlamaContext {
model_params.n_gpu_layers = 0
print("Running on simulator, force use n_gpu_layers = 0")
#endif
let model = llama_load_model_from_file(path, model_params)
let model = llama_model_load_from_file(path, model_params)
guard let model else {
print("Could not load model at \(path)")
throw LlamaError.couldNotInitializeContext
@ -151,7 +151,7 @@ actor LlamaContext {
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")
is_done = true
let new_token_str = String(cString: temporary_invalid_cchars + [0])

View file

@ -47,8 +47,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_model_get_vocab(model);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
if (llama_vocab_is_eog(vocab, id)) {
ret = "</s>";
} else {
ret = common_token_to_piece(ctx_llama, id);
@ -239,11 +243,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__);

View file

@ -384,7 +384,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
// make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama));
auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
if (n_image_embd != n_llama_embd) {
LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
@ -456,7 +456,7 @@ struct llava_embd_batch {
};
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
int n_eval = image_embed->n_image_pos - i;

View file

@ -54,7 +54,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
ctx_params.n_ctx = params->n_ctx;
}
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
@ -167,8 +167,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_model_get_vocab(model);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
if (llama_vocab_is_eog(vocab, id)) {
ret = "</s>";
} else {
ret = common_token_to_piece(ctx_llama, id);

View file

@ -27,7 +27,7 @@
static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
const int patch_size = 14 * 2;
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
@ -132,8 +132,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past, int * st_pos_id) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_model_get_vocab(model);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
if (llama_vocab_is_eog(vocab, id)) {
ret = "</s>";
} else {
ret = common_token_to_piece(ctx_llama, id);
@ -328,11 +332,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
@ -481,7 +484,7 @@ static void debug_test_mrope_2d() {
}
static void debug_dump_img_embed(struct llava_context * ctx_llava) {
int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
int ne = n_embd * 4;
float vals[56 * 56 * 3];
// float embd[ne];

View file

@ -61,6 +61,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model);
// Tokenize the prompt
std::vector<llama_token> inp;
std::vector<llama_token> all;
@ -147,7 +149,7 @@ int main(int argc, char ** argv) {
}
// here we keep adding new n-grams as we go
ngram_container ngrams_observed(llama_n_vocab(model), N, G);
ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
@ -297,7 +299,7 @@ int main(int argc, char ** argv) {
}
fflush(stdout);
if (llama_token_is_eog(model, id)) {
if (llama_vocab_is_eog(vocab, id)) {
has_eos = true;
}

View file

@ -36,6 +36,8 @@ int main(int argc, char ** argv){
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model);
// tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx, params.prompt, true, true);
@ -136,7 +138,7 @@ int main(int argc, char ** argv){
LOG("%s", token_str.c_str());
}
if (llama_token_is_eog(model, id)) {
if (llama_vocab_is_eog(vocab, id)) {
has_eos = true;
}

View file

@ -5,7 +5,6 @@
#include "sampling.h"
#include "llama.h"
#include <cassert>
#include <cstdio>
#include <cstring>
#include <ctime>
@ -163,6 +162,8 @@ int main(int argc, char ** argv) {
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
@ -196,7 +197,7 @@ int main(int argc, char ** argv) {
llama_attach_threadpool(ctx, threadpool, threadpool_batch);
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
if (n_ctx > n_ctx_train) {
@ -241,9 +242,9 @@ int main(int argc, char ** argv) {
}
}
const bool add_bos = llama_add_bos_token(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
if (!llama_model_has_encoder(model)) {
GGML_ASSERT(!llama_add_eos_token(model));
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
}
LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos);
@ -269,7 +270,7 @@ int main(int argc, char ** argv) {
// Should not run without any tokens
if (embd_inp.empty()) {
if (add_bos) {
embd_inp.push_back(llama_token_bos(model));
embd_inp.push_back(llama_vocab_bos(vocab));
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
} else {
LOG_ERR("input is empty\n");
@ -495,7 +496,7 @@ int main(int argc, char ** argv) {
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = llama_token_bos(model);
decoder_start_token_id = llama_vocab_bos(vocab);
}
embd_inp.clear();
@ -742,7 +743,7 @@ int main(int argc, char ** argv) {
}
// deal with end of generation tokens in interactive mode
if (llama_token_is_eog(model, common_sampler_last(smpl))) {
if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
LOG_DBG("found an EOG token\n");
if (params.interactive) {
@ -776,7 +777,7 @@ int main(int argc, char ** argv) {
if (params.input_prefix_bos) {
LOG_DBG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(model));
embd_inp.push_back(llama_vocab_bos(vocab));
}
std::string buffer;
@ -830,8 +831,8 @@ int main(int argc, char ** argv) {
// if user stop generation mid-way, we must add EOT to finish model's last response
if (need_insert_eot && format_chat) {
llama_token eot = llama_token_eot(model);
embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(model) : eot);
llama_token eot = llama_vocab_eot(vocab);
embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot);
need_insert_eot = false;
}
@ -866,7 +867,7 @@ int main(int argc, char ** argv) {
}
// end of generation
if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.interactive)) {
if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) {
LOG(" [end of text]\n");
break;
}

View file

@ -135,6 +135,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model);
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@ -358,7 +360,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (client.n_decoded > 2 &&
(llama_token_is_eog(model, id) ||
(llama_vocab_is_eog(vocab, id) ||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos)) {

View file

@ -70,15 +70,17 @@ int main(int argc, char ** argv) {
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
// initialize the context
llama_context_params ctx_params = common_context_params_to_llama(params);
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
ctx_params.n_ctx = llama_model_n_ctx_train(model)*n_grp + n_keep;
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
return 1;
@ -223,7 +225,7 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
LOG("\n");
break;

View file

@ -296,8 +296,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
LOG_INF("%s: tokenizing the input ..\n", __func__);
@ -338,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_vocab_n_tokens(vocab);
int count = 0;
double nll = 0.0;
@ -382,7 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
tokens[batch_start] = llama_vocab_bos(vocab);
}
const auto * batch_logits = llama_get_logits(ctx);
@ -444,8 +447,11 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
std::ofstream logits_stream;
if (!params.logits_file.empty()) {
@ -485,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_vocab_n_tokens(vocab);
int count = 0;
double nll = 0.0;
@ -557,7 +563,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
tokens[seq_start] = llama_vocab_bos(vocab);
}
for (int k = 0; k < batch_size; ++k) {
@ -732,6 +738,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
}
static void hellaswag_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
// Calculates hellaswag score (acc_norm) from prompt
//
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@ -765,7 +774,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
size_t hs_task_count = prompt_lines.size()/6;
LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM;
LOG_INF("================================= is_spm = %d\n", is_spm);
// The tasks should be randomized so the score stabilizes quickly.
@ -848,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_vocab_n_tokens(vocab);
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1072,6 +1081,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
*
*/
static void winogrande_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
constexpr int k_min_trailing_ctx = 3;
@ -1130,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_vocab_n_tokens(vocab);
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1374,6 +1385,8 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
// https://huggingface.co/datasets/truthful_qa
//
static void multiple_choice_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
std::istringstream strstream(params.prompt);
uint32_t n_task;
@ -1482,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_vocab_n_tokens(vocab);
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1655,6 +1668,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
}
static void kl_divergence(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.logits_file.empty()) {
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
return;
@ -1688,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
return;
}
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
if (n_vocab != llama_vocab_n_tokens(vocab)) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab));
}
std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
@ -1701,8 +1717,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const int n_batch = params.n_batch;
const int num_batches = (n_ctx + n_batch - 1)/n_batch;
const int nv = 2*((n_vocab + 1)/2) + 4;
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
@ -1761,7 +1777,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
@ -1995,7 +2011,7 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",

View file

@ -319,7 +319,7 @@ int main(int argc, char ** argv) {
auto cparams = llama_context_default_params();
cparams.n_ctx = 256;
ctx = llama_new_context_with_model(model, cparams);
ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());

View file

@ -159,7 +159,9 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
@ -192,8 +194,8 @@ int main(int argc, char ** argv) {
return 1;
}
// add eos if not present
if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) {
inp.push_back(llama_token_eos(model));
if (llama_vocab_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_vocab_eos(vocab))) {
inp.push_back(llama_vocab_eos(vocab));
}
chunk.tokens = inp;
}
@ -215,7 +217,7 @@ int main(int argc, char ** argv) {
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output
const int n_embd = llama_n_embd(model);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_chunks * n_embd, 0);
float * emb = embeddings.data();

View file

@ -685,7 +685,7 @@ class LlamaData {
// Initializes the context with the specified parameters
llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params));
llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params));
if (!context) {
printe("%s: error: failed to create the llama_context\n", __func__);
}
@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) {
int result = llama_chat_apply_template(
llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
}
@ -726,11 +726,11 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
}
// Function to tokenize the prompt
static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt,
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true);
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
true) < 0) {
printe("failed to tokenize the prompt\n");
return -1;
@ -753,9 +753,9 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
}
// convert the token to a string
static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) {
static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) {
char buf[256];
int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true);
int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true);
if (n < 0) {
printe("failed to convert token to piece\n");
return 1;
@ -773,8 +773,10 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
// helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
std::vector<llama_token> tokens;
if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) {
if (tokenize_prompt(vocab, prompt, tokens) < 0) {
return 1;
}
@ -790,12 +792,12 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
// sample the next token, check is it an end of generation?
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
if (llama_token_is_eog(llama_data.model.get(), new_token_id)) {
if (llama_vocab_is_eog(vocab, new_token_id)) {
break;
}
std::string piece;
if (convert_token_to_string(llama_data.model, new_token_id, piece)) {
if (convert_token_to_string(vocab, new_token_id, piece)) {
return 1;
}

View file

@ -97,7 +97,7 @@ int main(int argc, char ** argv) {
printf("\n\n");
// make new context
llama_context * ctx2 = llama_new_context_with_model(model, common_context_params_to_llama(params));
llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
@ -154,7 +154,7 @@ int main(int argc, char ** argv) {
}
// make new context
llama_context * ctx3 = llama_new_context_with_model(model, common_context_params_to_llama(params));
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);

View file

@ -98,7 +98,7 @@ struct slot_params {
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<common_lora_adapter_info> lora;
std::vector<common_adapter_lora_info> lora;
std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
@ -198,15 +198,17 @@ struct server_task {
bool metrics_reset_bucket = false;
// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_lora_adapter_info> set_lora;
std::vector<common_adapter_lora_info> set_lora;
server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx,
const common_params & params_base,
const json & data) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
slot_params params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
@ -329,7 +331,7 @@ struct server_task {
const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
@ -348,7 +350,7 @@ struct server_task {
params.sampling.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
for (auto tok : toks) {
params.sampling.logit_bias.push_back({tok, bias});
}
@ -1131,7 +1133,7 @@ struct server_slot {
common_speculative * spec = nullptr;
std::vector<common_lora_adapter_info> lora;
std::vector<common_adapter_lora_info> lora;
// the index relative to completion multi-task request
size_t index = 0;
@ -1633,6 +1635,8 @@ struct server_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
const llama_vocab * vocab = nullptr;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
@ -1690,10 +1694,12 @@ struct server_context {
return false;
}
vocab = llama_model_get_vocab(model);
n_ctx = llama_n_ctx(ctx);
add_bos_token = llama_add_bos_token(model);
has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params_base.speculative.model.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
@ -1736,7 +1742,8 @@ struct server_context {
bool validate_builtin_chat_template() const {
llama_chat_message chat[] = {{"user", "test"}};
int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
const char * tmpl = llama_model_chat_template(model);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
@ -1756,7 +1763,7 @@ struct server_context {
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return;
@ -1891,7 +1898,7 @@ struct server_context {
}
if (slot.params.ignore_eos && has_eos_token) {
slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
}
{
@ -2047,14 +2054,14 @@ struct server_context {
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
}
if (llama_token_is_eog(model, result.tok)) {
if (llama_vocab_is_eog(vocab, result.tok)) {
slot.stop = STOP_TYPE_EOS;
slot.has_next_token = false;
SLT_DBG(slot, "%s", "stopped by EOS\n");
}
const auto n_ctx_train = llama_n_ctx_train(model);
const auto n_ctx_train = llama_model_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
slot.truncated = true;
@ -2074,7 +2081,7 @@ struct server_context {
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
size_t n_vocab = llama_vocab_n_tokens(vocab);
if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size;
@ -2225,7 +2232,7 @@ struct server_context {
res->n_tokens = slot.n_prompt_tokens;
res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_n_embd(model);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embd_res(n_embd, 0.0f);
@ -2927,7 +2934,7 @@ struct server_context {
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, slot_batched->is_non_causal());
// apply lora, only need to do it once per batch
common_lora_adapters_apply(ctx, slot_batched->lora);
common_set_adapter_lora(ctx, slot_batched->lora);
}
// process the created batch of tokens
@ -3129,12 +3136,12 @@ struct server_context {
json model_meta() const {
return json {
{"vocab_type", llama_vocab_type (model)},
{"n_vocab", llama_n_vocab (model)},
{"n_ctx_train", llama_n_ctx_train (model)},
{"n_embd", llama_n_embd (model)},
{"n_params", llama_model_n_params(model)},
{"size", llama_model_size (model)},
{"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_vocab_n_tokens (vocab)},
{"n_ctx_train", llama_model_n_ctx_train(model)},
{"n_embd", llama_model_n_embd (model)},
{"n_params", llama_model_n_params (model)},
{"size", llama_model_size (model)},
};
}
};
@ -3639,7 +3646,7 @@ int main(int argc, char ** argv) {
std::vector<server_task> tasks;
try {
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);
@ -3649,7 +3656,6 @@ int main(int argc, char ** argv) {
task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl(
ctx_server.model,
ctx_server.ctx,
ctx_server.params_base,
data);
@ -3745,13 +3751,13 @@ int main(int argc, char ** argv) {
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. ";
}
if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "suffix token is missing. ";
}
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. ";
}
if (!err.empty()) {
@ -3797,10 +3803,10 @@ int main(int argc, char ** argv) {
data["input_extra"] = input_extra; // default to empty array if it's not exist
std::string prompt = json_value(data, "prompt", std::string());
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, false, true);
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
data["prompt"] = format_infill(
ctx_server.ctx,
ctx_server.vocab,
data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
@ -3857,7 +3863,7 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false);
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true);
if (with_pieces) {
for (const auto& token : tokens) {
@ -3933,7 +3939,7 @@ int main(int argc, char ** argv) {
}
}
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) {
@ -4033,20 +4039,20 @@ int main(int argc, char ** argv) {
return;
}
llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];
llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
// create and queue the task
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks;
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
tasks.push_back(task);
}

View file

@ -118,7 +118,7 @@ static json json_get_nested_values(const std::vector<std::string> & paths, const
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens;
@ -131,10 +131,10 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
llama_tokens p;
if (first) {
p = common_tokenize(ctx, s, add_special, parse_special);
p = common_tokenize(vocab, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(ctx, s, false, parse_special);
p = common_tokenize(vocab, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@ -148,7 +148,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
}
return prompt_tokens;
@ -166,11 +166,11 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
} else if (json_is_array_of_numbers(json_prompt)) {
// array of tokens
result.push_back(json_prompt.get<llama_tokens>());
@ -179,7 +179,7 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
result.reserve(json_prompt.size());
for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) {
// array of tokens
result.push_back(p.get<llama_tokens>());
@ -231,21 +231,23 @@ static size_t validate_utf8(const std::string& text) {
//
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result;
result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_token_bos(model));
result.push_back(llama_vocab_bos(vocab));
result.insert(result.end(), query.begin(), query.end());
result.push_back(llama_token_eos(model));
result.push_back(llama_token_sep(model));
result.push_back(llama_vocab_eos(vocab));
result.push_back(llama_vocab_sep(vocab));
result.insert(result.end(), doc.begin(), doc.end());
result.push_back(llama_token_eos(model));
result.push_back(llama_vocab_eos(vocab));
return result;
}
// format infill task
static llama_tokens format_infill(
const llama_context * ctx,
const llama_vocab * vocab,
const json & input_prefix,
const json & input_suffix,
const json & input_extra,
@ -272,15 +274,14 @@ static llama_tokens format_infill(
llama_tokens extra_tokens;
extra_tokens.reserve(n_ctx);
auto model = llama_get_model(ctx);
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: make project name an input
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
extra_tokens.push_back(llama_token_fim_rep(model));
extra_tokens.push_back(llama_vocab_fim_rep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : input_extra) {
@ -288,28 +289,28 @@ static llama_tokens format_infill(
const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = common_tokenize(ctx, text, false, false);
const auto chunk_tokens = common_tokenize(vocab, text, false, false);
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
@ -325,15 +326,15 @@ static llama_tokens format_infill(
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
if (llama_vocab_get_add_bos(vocab)) {
embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
}
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
@ -342,7 +343,7 @@ static llama_tokens format_infill(
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));
embd_inp.push_back(llama_vocab_fim_mid(vocab));
return embd_inp;
}
@ -764,14 +765,18 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
return data;
}
static std::string safe_json_to_str(json data) {
static std::string safe_json_to_str(const json & data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace);
}
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
std::vector<llama_token_data> cur;
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@ -799,8 +804,8 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
}
static bool are_lora_equal(
const std::vector<common_lora_adapter_info> & l1,
const std::vector<common_lora_adapter_info> & l2) {
const std::vector<common_adapter_lora_info> & l1,
const std::vector<common_adapter_lora_info> & l2) {
if (l1.size() != l2.size()) {
return false;
}
@ -814,10 +819,10 @@ static bool are_lora_equal(
}
// parse lora config from JSON request, returned a copy of lora_base with updated scale
static std::vector<common_lora_adapter_info> parse_lora_request(
const std::vector<common_lora_adapter_info> & lora_base,
static std::vector<common_adapter_lora_info> parse_lora_request(
const std::vector<common_adapter_lora_info> & lora_base,
const json & data) {
std::vector<common_lora_adapter_info> lora(lora_base);
std::vector<common_adapter_lora_info> lora(lora_base);
int max_idx = lora.size();
// clear existing value

View file

@ -75,12 +75,14 @@ int main(int argc, char ** argv) {
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
// initialize the context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = n_ctx;
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
@ -97,9 +99,9 @@ int main(int argc, char ** argv) {
std::string response;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n");
}
@ -124,13 +126,13 @@ int main(int argc, char ** argv) {
new_token_id = llama_sampler_sample(smpl, ctx, -1);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id)) {
if (llama_vocab_is_eog(vocab, new_token_id)) {
break;
}
// convert the token to a string, print it and add it to the response
char buf[256];
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
if (n < 0) {
GGML_ABORT("failed to convert token to piece\n");
}
@ -159,12 +161,14 @@ int main(int argc, char ** argv) {
break;
}
const char * tmpl = llama_model_chat_template(model);
// add the user input to the message list and format it
messages.push_back({"user", strdup(user.c_str())});
int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
if (new_len > (int)formatted.size()) {
formatted.resize(new_len);
new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
}
if (new_len < 0) {
fprintf(stderr, "failed to apply the chat template\n");
@ -181,7 +185,7 @@ int main(int argc, char ** argv) {
// add the response to the messages
messages.push_back({"assistant", strdup(response.c_str())});
prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
if (prev_len < 0) {
fprintf(stderr, "failed to apply the chat template\n");
return 1;

View file

@ -84,6 +84,7 @@ int main(int argc, char ** argv) {
model_params.n_gpu_layers = ngl;
llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
const llama_vocab * vocab = llama_model_get_vocab(model);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@ -93,11 +94,11 @@ int main(int argc, char ** argv) {
// tokenize the prompt
// find the number of tokens in the prompt
const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
// allocate space for the tokens and tokenize the prompt
std::vector<llama_token> prompt_tokens(n_prompt);
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
return 1;
}
@ -112,7 +113,7 @@ int main(int argc, char ** argv) {
// enable performance counters
ctx_params.no_perf = false;
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@ -131,7 +132,7 @@ int main(int argc, char ** argv) {
for (auto id : prompt_tokens) {
char buf[128];
int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true);
int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1;
@ -164,12 +165,12 @@ int main(int argc, char ** argv) {
new_token_id = llama_sampler_sample(smpl, ctx, -1);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id)) {
if (llama_vocab_is_eog(vocab, new_token_id)) {
break;
}
char buf[128];
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1;

View file

@ -45,6 +45,8 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
@ -196,7 +198,7 @@ int main(int argc, char ** argv) {
id_last = ids[i];
if (llama_token_is_eog(model_tgt, id_last)) {
if (llama_vocab_is_eog(vocab, id_last)) {
has_eos = true;
break;
}

View file

@ -90,10 +90,13 @@ int main(int argc, char ** argv) {
model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get();
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft);
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
@ -103,18 +106,18 @@ int main(int argc, char ** argv) {
}
if (
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
return 1;
}
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
@ -122,13 +125,13 @@ int main(int argc, char ** argv) {
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
@ -170,7 +173,7 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
//GGML_ASSERT(n_vocab == llama_vocab_n_tokens(model_dft));
// how many tokens to draft each time
int n_draft = params.speculative.n_max;
@ -386,7 +389,7 @@ int main(int argc, char ** argv) {
}
}
if (llama_token_is_eog(model_tgt, token_id)) {
if (llama_vocab_is_eog(vocab_tgt, token_id)) {
has_eos = true;
}
++n_predict;

View file

@ -344,8 +344,10 @@ int main(int raw_argc, char ** raw_argv) {
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_context_params ctx_params = llama_context_default_params();
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) {
fprintf(stderr, "Error: could not create context.\n");
return 1;
@ -365,7 +367,7 @@ int main(int raw_argc, char ** raw_argv) {
prompt = stdin_buffer.str();
}
const bool model_wants_add_bos = llama_add_bos_token(model);
const bool model_wants_add_bos = llama_vocab_get_add_bos(vocab);
const bool add_bos = model_wants_add_bos && !no_bos;
const bool parse_special = !no_parse_special;
const bool escape = !no_escape;
@ -375,7 +377,7 @@ int main(int raw_argc, char ** raw_argv) {
}
std::vector<llama_token> tokens;
tokens = common_tokenize(model, prompt, add_bos, parse_special);
tokens = common_tokenize(vocab, prompt, add_bos, parse_special);
if (printing_ids) {
printf("[");

View file

@ -414,15 +414,15 @@ static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
prompt.insert(prompt.end(), tokens.begin(), tokens.end());
}
static void prompt_add(llama_tokens & prompt, const llama_model * model, const std::string & txt, bool add_special, bool parse_special) {
auto tmp = common_tokenize(model, txt, add_special, parse_special);
static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
prompt_add(prompt, tmp);
}
static void prompt_init(llama_tokens & prompt, const llama_model * model) {
static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt.clear();
prompt_add(prompt, model, "<|im_start|>\n", true, true);
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}
int main(int argc, char ** argv) {
@ -462,6 +462,8 @@ int main(int argc, char ** argv) {
model_ttc = llama_init_ttc.model.get();
ctx_ttc = llama_init_ttc.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
// TODO: refactor in a common struct
params.model = params.vocoder.model;
params.model_url = params.vocoder.model_url;
@ -499,9 +501,9 @@ int main(int argc, char ** argv) {
std::vector<llama_token> prompt_inp;
prompt_init(prompt_inp, model_ttc);
prompt_init(prompt_inp, vocab);
prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
// convert the input text into the necessary format expected by OuteTTS
{
@ -509,10 +511,10 @@ int main(int argc, char ** argv) {
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
}
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
// disabled to save time on tokenizing each time
// TODO: load voices from the json files
@ -549,7 +551,7 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
auto tmp = common_tokenize(model_ttc, voice_data, false, true);
auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]);
@ -735,9 +737,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const auto * cands = common_sampler_get_candidates(smpl[i]);
// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model_ttc, new_token_id) || n_decode == n_predict) {
if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) {
std::string reason;
if (llama_token_is_eog(model_ttc, new_token_id)) {
if (llama_vocab_is_eog(vocab, new_token_id)) {
reason = "eos";
} else {
reason = "n_predict";
@ -873,7 +875,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
#if 1
// spectral operations
const int n_embd = llama_n_embd(model_cts);
const int n_embd = llama_model_n_embd(model_cts);
const float * embd = llama_get_embeddings(ctx_cts);
auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads);