llama : add new llama_decode() API that works with llama_batch

This commit is contained in:
Georgi Gerganov 2023-09-18 14:23:52 +03:00
parent 58bb5110ca
commit 9f42e75489
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
13 changed files with 146 additions and 75 deletions

45
llama.h
View file

@ -37,6 +37,8 @@
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
#define LLAMA_MAX_RNG_STATE (64*1024)
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
@ -70,9 +72,20 @@ extern "C" {
// TODO: not sure about these consts - might just get in the way all the time with no benefit
const llama_token * token;
const float * embd;
const float * embd;
const llama_pos * pos;
const llama_seq_id * seq_id;
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
//
// pos[i] = all_pos_0 + i*all_pos_1
//
llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL
bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
} llama_seq;
enum llama_log_level {
@ -312,9 +325,6 @@ extern "C" {
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
// Returns the maximum size in bytes of the state (rng, logits, embedding
// and kv_cache) - will often be smaller after compacting tokens
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
@ -336,19 +346,37 @@ extern "C" {
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
// Returns 0 on success
LLAMA_API int llama_eval(
LLAMA_API DEPRECATED(int llama_eval(
struct llama_context * ctx,
const llama_token * tokens,
uint32_t n_tokens,
int n_past,
int n_threads);
int n_threads),
"please use llama_decode() instead");
// Same as llama_eval, but use float matrix input directly.
LLAMA_API int llama_eval_embd(
LLAMA_API DEPRECATED(int llama_eval_embd(
struct llama_context * ctx,
const float * embd,
uint32_t n_tokens,
int n_past,
int n_threads),
"please use llama_decode() instead");
// Return batch for single sequence of tokens starting at pos_0
// If pos_0 == 0, the clear_kv flag will be auto set to true
//
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
//
LLAMA_API struct llama_batch llama_batch_get_one(
const llama_token * tokens,
uint32_t n_tokens,
llama_pos pos_0,
llama_seq_id seq_id);
LLAMA_API int llama_decode(
struct llama_context * ctx,
struct llama_batch batch,
int n_threads);
// Token logits obtained from the last call to llama_eval()
@ -434,6 +462,9 @@ extern "C" {
// Sampling functions
//
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);