llama : add llama_vocab
, functions -> methods, naming (#11110)
* llama : functions -> methods (#11110) * llama : add struct llama_vocab to the API (#11156) ggml-ci * hparams : move vocab params to llama_vocab (#11159) ggml-ci * vocab : more pimpl (#11165) ggml-ci * vocab : minor tokenization optimizations (#11160) ggml-ci Co-authored-by: Diego Devesa <slarengh@gmail.com> * lora : update API names (#11167) ggml-ci * llama : update API names to use correct prefix (#11174) * llama : update API names to use correct prefix ggml-ci * cont ggml-ci * cont ggml-ci * minor [no ci] * vocab : llama_vocab_add_[be]os -> llama_vocab_get_add_[be]os (#11174) ggml-ci * vocab : llama_vocab_n_vocab -> llama_vocab_n_tokens (#11174) ggml-ci --------- Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
c05e8c9934
commit
afa8a9ec9b
68 changed files with 5855 additions and 5400 deletions
|
@ -20,11 +20,11 @@ struct llama_sampler_deleter {
|
|||
void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); }
|
||||
};
|
||||
|
||||
struct llama_lora_adapter_deleter {
|
||||
void operator()(llama_lora_adapter * lora_adapter) { llama_lora_adapter_free(lora_adapter); }
|
||||
struct llama_adapter_lora_deleter {
|
||||
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
|
||||
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
|
||||
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
|
||||
typedef std::unique_ptr<llama_lora_adapter, llama_lora_adapter_deleter> llama_lora_adapter_ptr;
|
||||
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
|
||||
|
|
172
include/llama.h
172
include/llama.h
|
@ -56,7 +56,7 @@ extern "C" {
|
|||
// TODO: show sample usage
|
||||
//
|
||||
|
||||
// struct llama_vocab; // TODO: add in the future
|
||||
struct llama_vocab;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct llama_sampler;
|
||||
|
@ -385,8 +385,7 @@ extern "C" {
|
|||
} llama_chat_message;
|
||||
|
||||
// lora adapter
|
||||
// TODO: rename to llama_adapter_lora
|
||||
struct llama_lora_adapter;
|
||||
struct llama_adapter_lora;
|
||||
|
||||
// Helpers for getting default parameters
|
||||
// TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
|
||||
|
@ -400,18 +399,19 @@ extern "C" {
|
|||
// Call once at the start of the program
|
||||
LLAMA_API void llama_backend_init(void);
|
||||
|
||||
// Call once at the end of the program - currently only used for MPI
|
||||
LLAMA_API void llama_backend_free(void);
|
||||
|
||||
//optional:
|
||||
LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
|
||||
|
||||
// Optional: an auto threadpool gets created in ggml if not passed explicitly
|
||||
LLAMA_API void llama_attach_threadpool(
|
||||
struct llama_context * ctx,
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
|
||||
struct llama_context * ctx,
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
|
||||
// Call once at the end of the program - currently only used for MPI
|
||||
LLAMA_API void llama_backend_free(void);
|
||||
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
|
||||
const char * path_model,
|
||||
|
@ -427,11 +427,15 @@ extern "C" {
|
|||
|
||||
LLAMA_API void llama_model_free(struct llama_model * model);
|
||||
|
||||
// TODO: rename to llama_init_from_model
|
||||
LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||
LLAMA_API struct llama_context * llama_init_from_model(
|
||||
struct llama_model * model,
|
||||
struct llama_context_params params);
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||
struct llama_model * model,
|
||||
struct llama_context_params params),
|
||||
"use llama_init_from_model instead");
|
||||
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
|
@ -449,20 +453,30 @@ extern "C" {
|
|||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_head (const struct llama_model * model);
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead");
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead");
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
||||
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
|
||||
|
||||
// Functions to access the model's GGUF metadata scalar values
|
||||
// - The functions return the length of the string on success, or -1 on failure
|
||||
|
@ -488,6 +502,9 @@ extern "C" {
|
|||
// Returns the total size of all the tensors in the model in bytes
|
||||
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||
|
||||
// Get the default chat template. Returns nullptr if not available
|
||||
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
|
||||
|
||||
// Returns the total number of parameters in the model
|
||||
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
|
||||
|
||||
|
@ -515,34 +532,31 @@ extern "C" {
|
|||
//
|
||||
|
||||
// Load a LoRA adapter from file
|
||||
// TODO: rename to llama_adapter_lora_init
|
||||
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
|
||||
LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
|
||||
struct llama_model * model,
|
||||
const char * path_lora);
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// Note: loaded adapters will be free when the associated model is deleted
|
||||
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
|
||||
|
||||
// The following functions operate on a llama_context, hence the naming: llama_verb_...
|
||||
|
||||
// Add a loaded LoRA adapter to given context
|
||||
// This will not modify model's weight
|
||||
// TODO: rename to llama_set_adapter_lora
|
||||
LLAMA_API int32_t llama_lora_adapter_set(
|
||||
LLAMA_API int32_t llama_set_adapter_lora(
|
||||
struct llama_context * ctx,
|
||||
struct llama_lora_adapter * adapter,
|
||||
struct llama_adapter_lora * adapter,
|
||||
float scale);
|
||||
|
||||
// Remove a specific LoRA adapter from given context
|
||||
// Return -1 if the adapter is not present in the context
|
||||
// TODO: rename to llama_rm_adapter_lora
|
||||
LLAMA_API int32_t llama_lora_adapter_remove(
|
||||
LLAMA_API int32_t llama_rm_adapter_lora(
|
||||
struct llama_context * ctx,
|
||||
struct llama_lora_adapter * adapter);
|
||||
struct llama_adapter_lora * adapter);
|
||||
|
||||
// Remove all LoRA adapters from given context
|
||||
// TODO: rename to llama_clear_adapter_lora
|
||||
LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx);
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// Note: loaded adapters will be free when the associated model is deleted
|
||||
// TODO: rename to llama_adapter_lora_free
|
||||
LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
|
||||
LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
|
||||
|
||||
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
|
||||
// the currently loaded vector.
|
||||
|
@ -550,9 +564,8 @@ extern "C" {
|
|||
// to an n_embd x n_layers buffer starting from layer 1.
|
||||
// il_start and il_end are the layer range the vector should apply to (both inclusive)
|
||||
// See llama_control_vector_load in common to load a control vector.
|
||||
// TODO: rename to llama_adapter_cvec_apply
|
||||
LLAMA_API int32_t llama_control_vector_apply(
|
||||
struct llama_context * lctx,
|
||||
LLAMA_API int32_t llama_apply_adapter_cvec(
|
||||
struct llama_context * ctx,
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
|
@ -908,41 +921,57 @@ extern "C" {
|
|||
// Vocab
|
||||
//
|
||||
|
||||
LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
|
||||
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
// Identify if Token Id is a control token or a render-able token
|
||||
LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
// Special tokens
|
||||
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
||||
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
||||
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
|
||||
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
||||
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
||||
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
||||
LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
|
||||
LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence
|
||||
LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence
|
||||
LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn
|
||||
LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab); // classification
|
||||
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
|
||||
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
|
||||
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
|
||||
|
||||
LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
|
||||
LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
|
||||
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
|
||||
|
||||
// infill tokens
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
|
||||
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
|
||||
DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocabable_get_text instead");
|
||||
DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
|
||||
DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
|
||||
DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead");
|
||||
DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead");
|
||||
DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead");
|
||||
DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead");
|
||||
|
||||
//
|
||||
// Tokenization
|
||||
|
@ -958,7 +987,7 @@ extern "C" {
|
|||
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
||||
/// as plaintext. Does not insert a leading space.
|
||||
LLAMA_API int32_t llama_tokenize(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
const char * text,
|
||||
int32_t text_len,
|
||||
llama_token * tokens,
|
||||
|
@ -972,7 +1001,7 @@ extern "C" {
|
|||
// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
|
||||
// @param special If true, special tokens are rendered in the output.
|
||||
LLAMA_API int32_t llama_token_to_piece(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
llama_token token,
|
||||
char * buf,
|
||||
int32_t length,
|
||||
|
@ -986,7 +1015,7 @@ extern "C" {
|
|||
/// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
|
||||
/// @param unparse_special If true, special tokens are rendered in the output.
|
||||
LLAMA_API int32_t llama_detokenize(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
char * text,
|
||||
|
@ -1009,7 +1038,6 @@ extern "C" {
|
|||
/// @param length The size of the allocated buffer
|
||||
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
|
||||
LLAMA_API int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * tmpl,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
|
@ -1057,7 +1085,6 @@ extern "C" {
|
|||
// llama_sampler_free(smpl);
|
||||
//
|
||||
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
|
||||
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
|
||||
//
|
||||
|
||||
typedef void * llama_sampler_context_t;
|
||||
|
@ -1157,7 +1184,7 @@ extern "C" {
|
|||
float eta);
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
|
@ -1169,8 +1196,9 @@ extern "C" {
|
|||
float penalty_present); // 0.0 = disabled
|
||||
|
||||
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
||||
const struct llama_model * model,
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
||||
const struct llama_vocab * vocab,
|
||||
int32_t n_ctx_train,
|
||||
float dry_multiplier,
|
||||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
|
@ -1204,7 +1232,7 @@ extern "C" {
|
|||
// 3. discard non-EOG tokens with low prob
|
||||
// 4. if no tokens are left -> pick EOT
|
||||
//
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
|
||||
|
||||
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
||||
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue