llama : llama_n_vocab() now uses struct llama_vocab

This commit is contained in:
Georgi Gerganov 2025-01-09 15:57:57 +02:00
parent 68db76595e
commit 330bd07b82
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
16 changed files with 55 additions and 39 deletions

View file

@ -949,7 +949,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}
if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_n_vocab(model); i++) {
for (llama_token i = 0; i < llama_n_vocab(vocab); i++) {
if (llama_token_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY});

View file

@ -113,7 +113,10 @@ struct common_sampler {
void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
cur.resize(n_vocab);
@ -159,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_n_vocab(model),
llama_n_vocab(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
@ -208,7 +211,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));

View file

@ -105,15 +105,15 @@ bool common_speculative_are_compatible(
}
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int n_vocab_tgt = llama_n_vocab(vocab_tgt);
const int n_vocab_dft = llama_n_vocab(vocab_dft);
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
__func__, n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}

View file

@ -471,7 +471,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_n_vocab(vocab);
const int n_batch = params.n_batch;
int count = 0;

View file

@ -1402,7 +1402,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int32_t n_vocab = llama_n_vocab(model);
const int32_t n_vocab = llama_n_vocab(vocab);
std::vector<llama_token> tokens(n_batch);
@ -1426,7 +1426,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int32_t n_vocab = llama_n_vocab(model);
const int32_t n_vocab = llama_n_vocab(vocab);
llama_token token = llama_add_bos_token(vocab) ? llama_token_bos(vocab) : std::rand() % n_vocab;

View file

@ -149,7 +149,7 @@ int main(int argc, char ** argv) {
}
// here we keep adding new n-grams as we go
ngram_container ngrams_observed(llama_n_vocab(model), N, G);
ngram_container ngrams_observed(llama_n_vocab(vocab), N, G);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);

View file

@ -341,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_n_vocab(vocab);
int count = 0;
double nll = 0.0;
@ -491,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_n_vocab(vocab);
int count = 0;
double nll = 0.0;
@ -857,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_n_vocab(vocab);
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1081,6 +1081,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
*
*/
static void winogrande_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
constexpr int k_min_trailing_ctx = 3;
@ -1139,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_n_vocab(vocab);
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1383,6 +1385,8 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
// https://huggingface.co/datasets/truthful_qa
//
static void multiple_choice_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
std::istringstream strstream(params.prompt);
uint32_t n_task;
@ -1491,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_vocab = llama_n_vocab(vocab);
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1700,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
return;
}
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
if (n_vocab != llama_n_vocab(vocab)) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(vocab));
}
std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);

View file

@ -203,10 +203,12 @@ struct server_task {
server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx,
const common_params & params_base,
const json & data) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
slot_params params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
@ -329,7 +331,7 @@ struct server_task {
const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
@ -348,7 +350,7 @@ struct server_task {
params.sampling.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = common_tokenize(llama_get_vocab(model), el[0].get<std::string>(), false);
auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
for (auto tok : toks) {
params.sampling.logit_bias.push_back({tok, bias});
}
@ -2079,7 +2081,7 @@ struct server_context {
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
size_t n_vocab = llama_n_vocab(vocab);
if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size;
@ -3135,7 +3137,7 @@ struct server_context {
json model_meta() const {
return json {
{"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_n_vocab (model)},
{"n_vocab", llama_n_vocab (vocab)},
{"n_ctx_train", llama_n_ctx_train (model)},
{"n_embd", llama_n_embd (model)},
{"n_params", llama_model_n_params(model)},
@ -3654,7 +3656,6 @@ int main(int argc, char ** argv) {
task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl(
ctx_server.model,
ctx_server.ctx,
ctx_server.params_base,
data);

View file

@ -765,14 +765,18 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
return data;
}
static std::string safe_json_to_str(json data) {
static std::string safe_json_to_str(const json & data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace);
}
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
std::vector<llama_token_data> cur;
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {

View file

@ -116,8 +116,8 @@ int main(int argc, char ** argv) {
}
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int n_vocab_tgt = llama_n_vocab(vocab_tgt);
const int n_vocab_dft = llama_n_vocab(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
@ -125,7 +125,7 @@ int main(int argc, char ** argv) {
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1;
}

View file

@ -449,12 +449,13 @@ extern "C" {
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API const struct llama_vocab * llama_get_vocab(const struct llama_model * model);

View file

@ -3691,10 +3691,6 @@ void llama_model_free(struct llama_model * model) {
delete model;
}
int32_t llama_n_vocab(const struct llama_model * model) {
return model->hparams.n_vocab;
}
int32_t llama_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train;
}

View file

@ -371,7 +371,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
// TODO: do not allocate each time
std::vector<llama_token_data> cur;

View file

@ -1987,16 +1987,20 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
for (auto id : cache_special_tokens) {
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
}
for (auto token : {"</s>"}) {
for (const auto * token : {"</s>"}) {
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
}
for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) {
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
}
}
}
}
int32_t llama_n_vocab(const struct llama_vocab * vocab) {
return vocab->n_vocab();
}
enum llama_vocab_type llama_vocab::get_type() const {
return type;
}

View file

@ -77,7 +77,7 @@ int main(int argc, char **argv) {
atexit([]() { console::cleanup(); });
#endif
const int n_vocab = llama_n_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i));

View file

@ -65,7 +65,7 @@ int main(int argc, char ** argv) {
atexit([]() { console::cleanup(); });
#endif
const int n_vocab = llama_n_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i), true);