llama.cpp : add llama_get_model

common : add llama_tokenize from model
This commit is contained in:
slaren 2023-09-26 23:23:59 +02:00
parent 0e9ed7f84f
commit 8f5b0eaa8a
4 changed files with 22 additions and 4 deletions

View file

@ -821,16 +821,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
//
std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const struct llama_context * ctx,
const std::string & text,
bool add_bos) {
return llama_tokenize(llama_get_model(ctx), text, add_bos);
}
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
n_tokens = llama_tokenize_with_model(model, text.data(), text.length(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
int check = llama_tokenize_with_model(model, text.data(), text.length(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);

View file

@ -143,7 +143,12 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
// tokenizes a string into a vector of tokens
// should work similar to Python's `tokenizer.encode`
std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const struct llama_context * ctx,
const std::string & text,
bool add_bos);
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos);

View file

@ -6427,6 +6427,10 @@ void llama_free(struct llama_context * ctx) {
delete ctx;
}
const llama_model * llama_get_model(const struct llama_context * ctx) {
return &ctx->model;
}
int llama_n_vocab(const struct llama_context * ctx) {
return llama_model_n_vocab(&ctx->model);
}

View file

@ -251,6 +251,8 @@ extern "C" {
LLAMA_API bool llama_mmap_supported (void);
LLAMA_API bool llama_mlock_supported(void);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);