common : add llama_batch_add() and llama_batch_clear() helpers

This commit is contained in:
Georgi Gerganov 2023-10-16 12:41:33 +03:00
parent 005949109d
commit 360a333145
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
10 changed files with 98 additions and 122 deletions

View file

@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return cparams;
}
void llama_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos,
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.n_tokens++;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);

View file

@ -70,6 +70,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
// TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
@ -124,10 +125,23 @@ void process_escapes(std::string& input);
// Model utils
//
// TODO: avoid tuplue, use struct
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
// Batch utils
void llama_batch_clear(struct llama_batch & batch);
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Vocab utils
//

View file

@ -53,6 +53,19 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->cur.clear();
}
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
}
if (src->grammar) {
dst->grammar = llama_grammar_copy(src->grammar);
}
dst->prev = src->prev;
}
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,

View file

@ -67,6 +67,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call