llama : add llama_vocab
, functions -> methods, naming (#11110)
* llama : functions -> methods (#11110) * llama : add struct llama_vocab to the API (#11156) ggml-ci * hparams : move vocab params to llama_vocab (#11159) ggml-ci * vocab : more pimpl (#11165) ggml-ci * vocab : minor tokenization optimizations (#11160) ggml-ci Co-authored-by: Diego Devesa <slarengh@gmail.com> * lora : update API names (#11167) ggml-ci * llama : update API names to use correct prefix (#11174) * llama : update API names to use correct prefix ggml-ci * cont ggml-ci * cont ggml-ci * minor [no ci] * vocab : llama_vocab_add_[be]os -> llama_vocab_get_add_[be]os (#11174) ggml-ci * vocab : llama_vocab_n_vocab -> llama_vocab_n_tokens (#11174) ggml-ci --------- Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
c05e8c9934
commit
afa8a9ec9b
68 changed files with 5855 additions and 5400 deletions
|
@ -296,8 +296,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
||||
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
|
||||
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
|
||||
LOG_INF("%s: tokenizing the input ..\n", __func__);
|
||||
|
||||
|
@ -338,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
|
||||
const int n_batch = params.n_batch;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
int count = 0;
|
||||
double nll = 0.0;
|
||||
|
@ -382,7 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||
tokens[batch_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
const auto * batch_logits = llama_get_logits(ctx);
|
||||
|
@ -444,8 +447,11 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
|||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
||||
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
|
||||
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
|
||||
std::ofstream logits_stream;
|
||||
if (!params.logits_file.empty()) {
|
||||
|
@ -485,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
|||
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
|
||||
const int n_batch = params.n_batch;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
int count = 0;
|
||||
double nll = 0.0;
|
||||
|
@ -557,7 +563,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
|||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
|
||||
tokens[seq_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
for (int k = 0; k < batch_size; ++k) {
|
||||
|
@ -732,6 +738,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
|
|||
}
|
||||
|
||||
static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// Calculates hellaswag score (acc_norm) from prompt
|
||||
//
|
||||
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
|
||||
|
@ -765,7 +774,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
|||
size_t hs_task_count = prompt_lines.size()/6;
|
||||
LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
|
||||
|
||||
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
|
||||
const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM;
|
||||
LOG_INF("================================= is_spm = %d\n", is_spm);
|
||||
|
||||
// The tasks should be randomized so the score stabilizes quickly.
|
||||
|
@ -848,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
|||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_batch = params.n_batch;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
const int max_tasks_per_batch = 32;
|
||||
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||
|
@ -1072,6 +1081,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
|
|||
*
|
||||
*/
|
||||
static void winogrande_score(llama_context * ctx, const common_params & params) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
constexpr int k_min_trailing_ctx = 3;
|
||||
|
||||
|
@ -1130,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
|||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_batch = params.n_batch;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
const int max_tasks_per_batch = 128;
|
||||
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||
|
@ -1374,6 +1385,8 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
|
|||
// https://huggingface.co/datasets/truthful_qa
|
||||
//
|
||||
static void multiple_choice_score(llama_context * ctx, const common_params & params) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
std::istringstream strstream(params.prompt);
|
||||
uint32_t n_task;
|
||||
|
@ -1482,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
|||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_batch = params.n_batch;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
const int max_tasks_per_batch = 32;
|
||||
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
|
||||
|
@ -1655,6 +1668,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
|||
}
|
||||
|
||||
static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.logits_file.empty()) {
|
||||
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
|
||||
return;
|
||||
|
@ -1688,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
|
||||
return;
|
||||
}
|
||||
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
|
||||
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
|
||||
if (n_vocab != llama_vocab_n_tokens(vocab)) {
|
||||
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab));
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
|
||||
|
@ -1701,8 +1717,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
const int n_batch = params.n_batch;
|
||||
const int num_batches = (n_ctx + n_batch - 1)/n_batch;
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
|
||||
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
|
||||
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
||||
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
|
||||
|
@ -1761,7 +1777,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||
tokens[batch_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
@ -1995,7 +2011,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
const int n_ctx_train = llama_n_ctx_train(model);
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model);
|
||||
|
||||
if (params.n_ctx > n_ctx_train) {
|
||||
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue