ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-10 15:06:41 +02:00
parent 1586ed5061
commit 1d9f1f2778
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
40 changed files with 134 additions and 96 deletions

View file

@ -886,7 +886,7 @@ struct common_init_result common_init_from_params(common_params & params) {
auto cparams = common_context_params_to_llama(params); auto cparams = common_context_params_to_llama(params);
llama_context * lctx = llama_new_context_with_model(model, cparams); llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) { if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_model_free(model); llama_model_free(model);
@ -900,7 +900,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (!params.control_vectors.empty()) { if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model);
const auto cvec = common_control_vector_load(params.control_vectors); const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) { if (cvec.n_embd == -1) {
@ -949,7 +949,7 @@ struct common_init_result common_init_from_params(common_params & params) {
} }
if (params.sampling.ignore_eos) { if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_n_vocab(vocab); i++) { for (llama_token i = 0; i < llama_vocab_n_vocab(vocab); i++) {
if (llama_token_is_eog(vocab, i)) { if (llama_token_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY}); params.sampling.logit_bias.push_back({i, -INFINITY});

View file

@ -116,7 +116,7 @@ struct common_sampler {
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
cur.resize(n_vocab); cur.resize(n_vocab);
@ -162,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias( llama_sampler_init_logit_bias(
llama_n_vocab(vocab), llama_vocab_n_vocab(vocab),
params.logit_bias.size(), params.logit_bias.size(),
params.logit_bias.data())); params.logit_bias.data()));
@ -177,7 +177,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
c_breakers.push_back(str.c_str()); c_breakers.push_back(str.c_str());
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break; break;
case COMMON_SAMPLER_TYPE_TOP_K: case COMMON_SAMPLER_TYPE_TOP_K:
@ -211,7 +211,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) { } else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_vocab(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) { } else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));

View file

@ -105,15 +105,15 @@ bool common_speculative_are_compatible(
} }
{ {
const int n_vocab_tgt = llama_n_vocab(vocab_tgt); const int n_vocab_tgt = llama_vocab_n_vocab(vocab_tgt);
const int n_vocab_dft = llama_n_vocab(vocab_dft); const int n_vocab_dft = llama_vocab_n_vocab(vocab_dft);
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); __func__, n_vocab_tgt, llama_vocab_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false; return false;
} }

View file

@ -50,7 +50,7 @@ int main(int argc, char ** argv) {
// ensure enough sequences are available // ensure enough sequences are available
ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end()); ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end());
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

View file

@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_req; ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_predict, n_parallel); ctx_params.n_batch = std::max(n_predict, n_parallel);
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
auto sparams = llama_sampler_chain_default_params(); auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false; sparams.no_perf = false;

View file

@ -911,7 +911,7 @@ int main(int argc, char ** argv) {
load_vocab(params.fn_vocab_model, &config, &vocab); load_vocab(params.fn_vocab_model, &config, &vocab);
struct my_llama_model model; struct my_llama_model model;
model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx); model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx);
model.hparams.n_ctx = params.n_ctx; model.hparams.n_ctx = params.n_ctx;
model.hparams.n_embd = config.dim; //params.n_embd; model.hparams.n_embd = config.dim; //params.n_embd;
model.hparams.n_ff = config.hidden_dim; model.hparams.n_ff = config.hidden_dim;

View file

@ -423,8 +423,8 @@ int main(int argc, char ** argv) {
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
// int n_ctx = llama_n_ctx(ctx); // int n_ctx = llama_n_ctx(ctx);
int n_layers = llama_n_layer(model); int n_layers = llama_model_n_layer(model);
int n_embd = llama_n_embd(model); int n_embd = llama_model_n_embd(model);
// get model hint param (a.k.a model arch name) // get model hint param (a.k.a model arch name)
char model_hint[128]; char model_hint[128];

View file

@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
@ -183,7 +183,7 @@ int main(int argc, char ** argv) {
} }
// allocate output // allocate output
const int n_embd = llama_n_embd(model); const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_embd_count * n_embd, 0); std::vector<float> embeddings(n_embd_count * n_embd, 0);
float * emb = embeddings.data(); float * emb = embeddings.data();

View file

@ -53,7 +53,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
llama_decode(ctx, batch); llama_decode(ctx, batch);
// get embedding dimensions // get embedding dimensions
uint64_t n_embd = llama_n_embd(model); uint64_t n_embd = llama_model_n_embd(model);
// allocate embedding output // allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f); std::vector<float> emb_unorm(n_embd, 0.0f);
@ -171,7 +171,7 @@ int main(int argc, char * argv[]) {
llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams);
// create generation context // create generation context
llama_context * ctx = llama_new_context_with_model(model, cparams); llama_context * ctx = llama_init_from_model(model, cparams);
auto sparams = llama_sampler_chain_default_params(); auto sparams = llama_sampler_chain_default_params();
@ -200,7 +200,7 @@ int main(int argc, char * argv[]) {
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
const int n_embd = llama_n_embd(model); const int n_embd = llama_model_n_embd(model);
const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);

View file

@ -7,7 +7,6 @@
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <ctime> #include <ctime>
#include <sstream>
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
@ -40,7 +39,7 @@ public:
void set_params(common_params params) { m_params = std::move(params); } void set_params(common_params params) { m_params = std::move(params); }
bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
void save_imatrix(int ncall = -1) const; void save_imatrix(int ncall = -1) const;
bool load_imatrix(const char * file_name); bool load_imatrix(const char * fname);
private: private:
std::unordered_map<std::string, Stats> m_stats; std::unordered_map<std::string, Stats> m_stats;
common_params m_params; common_params m_params;
@ -471,7 +470,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
int count = 0; int count = 0;
@ -630,7 +629,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) { if (params.n_ctx > n_ctx_train) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx); __func__, n_ctx_train, params.n_ctx);

View file

@ -141,7 +141,7 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
LOG_DBG("n_ctx: %d\n", n_ctx); LOG_DBG("n_ctx: %d\n", n_ctx);

View file

@ -1402,7 +1402,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_n_vocab(vocab); const int32_t n_vocab = llama_vocab_n_vocab(vocab);
std::vector<llama_token> tokens(n_batch); std::vector<llama_token> tokens(n_batch);
@ -1426,7 +1426,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_n_vocab(vocab); const int32_t n_vocab = llama_vocab_n_vocab(vocab);
llama_token token = llama_add_bos_token(vocab) ? llama_token_bos(vocab) : std::rand() % n_vocab; llama_token token = llama_add_bos_token(vocab) ? llama_token_bos(vocab) : std::rand() % n_vocab;
@ -1539,7 +1539,7 @@ int main(int argc, char ** argv) {
prev_inst = &inst; prev_inst = &inst;
} }
llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams()); llama_context * ctx = llama_init_from_model(lmodel, inst.to_llama_cparams());
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str()); fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
llama_model_free(lmodel); llama_model_free(lmodel);

View file

@ -243,11 +243,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
llama_context_params ctx_params = common_context_params_to_llama(*params); llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) { if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);

View file

@ -384,7 +384,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
// make sure that the correct mmproj was used, i.e., compare apples to apples // make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama));
auto n_image_embd = clip_n_mmproj_embd(ctx_clip); auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
if (n_image_embd != n_llama_embd) { if (n_image_embd != n_llama_embd) {
LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
@ -456,7 +456,7 @@ struct llava_embd_batch {
}; };
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama)); int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
int n_eval = image_embed->n_image_pos - i; int n_eval = image_embed->n_image_pos - i;

View file

@ -54,7 +54,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
ctx_params.n_ctx = params->n_ctx; ctx_params.n_ctx = params->n_ctx;
} }
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) { if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);

View file

@ -27,7 +27,7 @@
static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) { int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama)); int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
const int patch_size = 14 * 2; const int patch_size = 14 * 2;
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0); const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0); const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
@ -332,11 +332,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
llama_context_params ctx_params = common_context_params_to_llama(*params); llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) { if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);
@ -485,7 +484,7 @@ static void debug_test_mrope_2d() {
} }
static void debug_dump_img_embed(struct llava_context * ctx_llava) { static void debug_dump_img_embed(struct llava_context * ctx_llava) {
int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama)); int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
int ne = n_embd * 4; int ne = n_embd * 4;
float vals[56 * 56 * 3]; float vals[56 * 56 * 3];
// float embd[ne]; // float embd[ne];

View file

@ -149,7 +149,7 @@ int main(int argc, char ** argv) {
} }
// here we keep adding new n-grams as we go // here we keep adding new n-grams as we go
ngram_container ngrams_observed(llama_n_vocab(vocab), N, G); ngram_container ngrams_observed(llama_vocab_n_vocab(vocab), N, G);
// debug // debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);

View file

@ -5,7 +5,6 @@
#include "sampling.h" #include "sampling.h"
#include "llama.h" #include "llama.h"
#include <cassert>
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <ctime> #include <ctime>
@ -198,7 +197,7 @@ int main(int argc, char ** argv) {
llama_attach_threadpool(ctx, threadpool, threadpool_batch); llama_attach_threadpool(ctx, threadpool, threadpool_batch);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
if (n_ctx > n_ctx_train) { if (n_ctx > n_ctx_train) {

View file

@ -76,11 +76,11 @@ int main(int argc, char ** argv) {
llama_context_params ctx_params = common_context_params_to_llama(params); llama_context_params ctx_params = common_context_params_to_llama(params);
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; ctx_params.n_ctx = llama_model_n_ctx_train(model)*n_grp + n_keep;
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) { if (ctx == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);
return 1; return 1;

View file

@ -341,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -491,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -857,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1141,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
const int max_tasks_per_batch = 128; const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1495,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1704,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
return; return;
} }
if (n_vocab != llama_n_vocab(vocab)) { if (n_vocab != llama_vocab_n_vocab(vocab)) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(vocab)); LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_vocab(vocab));
} }
std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk); std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
@ -2011,7 +2011,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) { if (params.n_ctx > n_ctx_train) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",

View file

@ -319,7 +319,7 @@ int main(int argc, char ** argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
cparams.n_ctx = 256; cparams.n_ctx = 256;
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());

View file

@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
@ -217,7 +217,7 @@ int main(int argc, char ** argv) {
struct llama_batch batch = llama_batch_init(n_batch, 0, 1); struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output // allocate output
const int n_embd = llama_n_embd(model); const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_chunks * n_embd, 0); std::vector<float> embeddings(n_chunks * n_embd, 0);
float * emb = embeddings.data(); float * emb = embeddings.data();

View file

@ -685,7 +685,7 @@ class LlamaData {
// Initializes the context with the specified parameters // Initializes the context with the specified parameters
llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params)); llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params));
if (!context) { if (!context) {
printe("%s: error: failed to create the llama_context\n", __func__); printe("%s: error: failed to create the llama_context\n", __func__);
} }

View file

@ -97,7 +97,7 @@ int main(int argc, char ** argv) {
printf("\n\n"); printf("\n\n");
// make new context // make new context
llama_context * ctx2 = llama_new_context_with_model(model, common_context_params_to_llama(params)); llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_sampler * smpl2 = llama_sampler_chain_init(sparams); llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
@ -154,7 +154,7 @@ int main(int argc, char ** argv) {
} }
// make new context // make new context
llama_context * ctx3 = llama_new_context_with_model(model, common_context_params_to_llama(params)); llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_sampler * smpl3 = llama_sampler_chain_init(sparams); llama_sampler * smpl3 = llama_sampler_chain_init(sparams);

View file

@ -331,7 +331,7 @@ struct server_task {
const auto & logit_bias = data.find("logit_bias"); const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) { if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
for (const auto & el : *logit_bias) { for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect // TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) { if (el.is_array() && el.size() == 2) {
@ -1763,7 +1763,7 @@ struct server_context {
if (model_dft) { if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) { if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n"); SRV_ERR("%s", "failed to create draft context\n");
return; return;
@ -2061,7 +2061,7 @@ struct server_context {
SLT_DBG(slot, "%s", "stopped by EOS\n"); SLT_DBG(slot, "%s", "stopped by EOS\n");
} }
const auto n_ctx_train = llama_n_ctx_train(model); const auto n_ctx_train = llama_model_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
slot.truncated = true; slot.truncated = true;
@ -2081,7 +2081,7 @@ struct server_context {
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs; size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_n_vocab(vocab); size_t n_vocab = llama_vocab_n_vocab(vocab);
if (post_sampling) { if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl); const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size; const size_t max_probs = cur_p->size;
@ -2232,7 +2232,7 @@ struct server_context {
res->n_tokens = slot.n_prompt_tokens; res->n_tokens = slot.n_prompt_tokens;
res->oaicompat = slot.params.oaicompat; res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_n_embd(model); const int n_embd = llama_model_n_embd(model);
std::vector<float> embd_res(n_embd, 0.0f); std::vector<float> embd_res(n_embd, 0.0f);
@ -3137,10 +3137,10 @@ struct server_context {
json model_meta() const { json model_meta() const {
return json { return json {
{"vocab_type", llama_vocab_type (vocab)}, {"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_n_vocab (vocab)}, {"n_vocab", llama_vocab_n_vocab (vocab)},
{"n_ctx_train", llama_n_ctx_train (model)}, {"n_ctx_train", llama_model_n_ctx_train(model)},
{"n_embd", llama_n_embd (model)}, {"n_embd", llama_model_n_embd (model)},
{"n_params", llama_model_n_params(model)}, {"n_params", llama_model_n_params (model)},
{"size", llama_model_size (model)}, {"size", llama_model_size (model)},
}; };
} }

View file

@ -776,7 +776,7 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
cur.resize(n_vocab); cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) { for (llama_token token_id = 0; token_id < n_vocab; token_id++) {

View file

@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_ctx; ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = n_ctx; ctx_params.n_batch = n_ctx;
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) { if (!ctx) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1; return 1;

View file

@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
// enable performance counters // enable performance counters
ctx_params.no_perf = false; ctx_params.no_perf = false;
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

View file

@ -116,8 +116,8 @@ int main(int argc, char ** argv) {
} }
{ {
const int n_vocab_tgt = llama_n_vocab(vocab_tgt); const int n_vocab_tgt = llama_vocab_n_vocab(vocab_tgt);
const int n_vocab_dft = llama_n_vocab(vocab_dft); const int n_vocab_dft = llama_vocab_n_vocab(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft ? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt; : n_vocab_dft - n_vocab_tgt;
@ -125,7 +125,7 @@ int main(int argc, char ** argv) {
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); n_vocab_tgt, llama_vocab_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1; return 1;
} }
@ -173,7 +173,7 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab // the 2 models should have the same vocab
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft)); //GGML_ASSERT(n_vocab == llama_vocab_n_vocab(model_dft));
// how many tokens to draft each time // how many tokens to draft each time
int n_draft = params.speculative.n_max; int n_draft = params.speculative.n_max;

View file

@ -347,7 +347,7 @@ int main(int raw_argc, char ** raw_argv) {
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) { if (!ctx) {
fprintf(stderr, "Error: could not create context.\n"); fprintf(stderr, "Error: could not create context.\n");
return 1; return 1;

View file

@ -875,7 +875,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
#if 1 #if 1
// spectral operations // spectral operations
const int n_embd = llama_n_embd(model_cts); const int n_embd = llama_model_n_embd(model_cts);
const float * embd = llama_get_embeddings(ctx_cts); const float * embd = llama_get_embeddings(ctx_cts);
auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads); auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads);

View file

@ -427,11 +427,15 @@ extern "C" {
LLAMA_API void llama_model_free(struct llama_model * model); LLAMA_API void llama_model_free(struct llama_model * model);
// TODO: rename to llama_init_from_model LLAMA_API struct llama_context * llama_init_from_model(
LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model, struct llama_model * model,
struct llama_context_params params); struct llama_context_params params);
DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params),
"use llama_init_from_model instead");
// Frees all allocated memory // Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx); LLAMA_API void llama_free(struct llama_context * ctx);
@ -449,12 +453,12 @@ extern "C" {
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
LLAMA_API int32_t llama_n_embd (const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
LLAMA_API int32_t llama_n_layer (const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead");
LLAMA_API int32_t llama_n_head (const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead");
LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab); DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_vocab instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
@ -462,11 +466,18 @@ extern "C" {
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor // Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab); LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
LLAMA_API int32_t llama_vocab_n_vocab(const struct llama_vocab * vocab);
// Functions to access the model's GGUF metadata scalar values // Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure // - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure // - The output string is always null-terminated and cleared on failure

View file

@ -3747,22 +3747,42 @@ void llama_model_free(struct llama_model * model) {
delete model; delete model;
} }
int32_t llama_n_ctx_train(const struct llama_model * model) { int32_t llama_model_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train; return model->hparams.n_ctx_train;
} }
int32_t llama_n_embd(const struct llama_model * model) { int32_t llama_model_n_embd(const struct llama_model * model) {
return model->hparams.n_embd; return model->hparams.n_embd;
} }
int32_t llama_n_layer(const struct llama_model * model) { int32_t llama_model_n_layer(const struct llama_model * model) {
return model->hparams.n_layer; return model->hparams.n_layer;
} }
int32_t llama_n_head(const struct llama_model * model) { int32_t llama_model_n_head(const struct llama_model * model) {
return model->hparams.n_head(); return model->hparams.n_head();
} }
// deprecated
int32_t llama_n_ctx_train(const struct llama_model * model) {
return llama_model_n_ctx_train(model);
}
// deprecated
int32_t llama_n_embd(const struct llama_model * model) {
return llama_model_n_embd(model);
}
// deprecated
int32_t llama_n_layer(const struct llama_model * model) {
return llama_model_n_layer(model);
}
// deprecated
int32_t llama_n_head(const struct llama_model * model) {
return llama_model_n_head(model);
}
enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
switch (model->arch) { switch (model->arch) {
// these models do not use RoPE // these models do not use RoPE

View file

@ -374,7 +374,7 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
// TODO: do not allocate each time // TODO: do not allocate each time
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;

View file

@ -3026,10 +3026,15 @@ void llama_vocab::print_info() const {
// interface implementation // interface implementation
// //
int32_t llama_n_vocab(const struct llama_vocab * vocab) { int32_t llama_vocab_n_vocab(const struct llama_vocab * vocab) {
return vocab->n_vocab(); return vocab->n_vocab();
} }
// deprecated
int32_t llama_n_vocab(const struct llama_vocab * vocab) {
return llama_vocab_n_vocab(vocab);
}
enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) { enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
return vocab->get_type(); return vocab->get_type();
} }

View file

@ -9502,7 +9502,7 @@ struct llama_model * llama_model_load_from_file(
return model; return model;
} }
struct llama_context * llama_new_context_with_model( struct llama_context * llama_init_from_model(
struct llama_model * model, struct llama_model * model,
struct llama_context_params params) { struct llama_context_params params) {
@ -9853,6 +9853,12 @@ struct llama_context * llama_new_context_with_model(
return ctx; return ctx;
} }
struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params) {
return llama_init_from_model(model, params);
}
// //
// kv cache // kv cache
// //

View file

@ -14,7 +14,7 @@ int main(int argc, char ** argv) {
std::thread([&model_path]() { std::thread([&model_path]() {
llama_backend_init(); llama_backend_init();
auto * model = llama_model_load_from_file(model_path, llama_model_default_params()); auto * model = llama_model_load_from_file(model_path, llama_model_default_params());
auto * ctx = llama_new_context_with_model(model, llama_context_default_params()); auto * ctx = llama_init_from_model(model, llama_context_default_params());
llama_free(ctx); llama_free(ctx);
llama_model_free(model); llama_model_free(model);
llama_backend_free(); llama_backend_free();

View file

@ -161,7 +161,7 @@ int main(int argc, char **argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());

View file

@ -55,7 +55,7 @@ int main(int argc, char **argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@ -77,7 +77,7 @@ int main(int argc, char **argv) {
atexit([]() { console::cleanup(); }); atexit([]() { console::cleanup(); });
#endif #endif
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i)); std::string str = common_detokenize(ctx, std::vector<int>(1, i));

View file

@ -43,7 +43,7 @@ int main(int argc, char ** argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@ -65,7 +65,7 @@ int main(int argc, char ** argv) {
atexit([]() { console::cleanup(); }); atexit([]() { console::cleanup(); });
#endif #endif
const int n_vocab = llama_n_vocab(vocab); const int n_vocab = llama_vocab_n_vocab(vocab);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i), true); std::string str = common_detokenize(ctx, std::vector<int>(1, i), true);