llama : add new llama_decode() API that works with llama_batch
This commit is contained in:
parent
58bb5110ca
commit
9f42e75489
13 changed files with 146 additions and 75 deletions
45
llama.h
45
llama.h
|
@ -37,6 +37,8 @@
|
|||
|
||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||
|
||||
#define LLAMA_MAX_RNG_STATE (64*1024)
|
||||
|
||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
|
@ -70,9 +72,20 @@ extern "C" {
|
|||
|
||||
// TODO: not sure about these consts - might just get in the way all the time with no benefit
|
||||
const llama_token * token;
|
||||
const float * embd;
|
||||
const float * embd;
|
||||
const llama_pos * pos;
|
||||
const llama_seq_id * seq_id;
|
||||
|
||||
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
||||
// for future-proof code, use the above fields instead and ignore everything below
|
||||
//
|
||||
// pos[i] = all_pos_0 + i*all_pos_1
|
||||
//
|
||||
llama_pos all_pos_0; // used if pos == NULL
|
||||
llama_pos all_pos_1; // used if pos == NULL
|
||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||
|
||||
bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
|
||||
} llama_seq;
|
||||
|
||||
enum llama_log_level {
|
||||
|
@ -312,9 +325,6 @@ extern "C" {
|
|||
|
||||
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||
|
||||
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||
// and kv_cache) - will often be smaller after compacting tokens
|
||||
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
|
||||
|
@ -336,19 +346,37 @@ extern "C" {
|
|||
// tokens + n_tokens is the provided batch of new tokens to process
|
||||
// n_past is the number of tokens to use from previous eval calls
|
||||
// Returns 0 on success
|
||||
LLAMA_API int llama_eval(
|
||||
LLAMA_API DEPRECATED(int llama_eval(
|
||||
struct llama_context * ctx,
|
||||
const llama_token * tokens,
|
||||
uint32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads);
|
||||
int n_threads),
|
||||
"please use llama_decode() instead");
|
||||
|
||||
// Same as llama_eval, but use float matrix input directly.
|
||||
LLAMA_API int llama_eval_embd(
|
||||
LLAMA_API DEPRECATED(int llama_eval_embd(
|
||||
struct llama_context * ctx,
|
||||
const float * embd,
|
||||
uint32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads),
|
||||
"please use llama_decode() instead");
|
||||
|
||||
// Return batch for single sequence of tokens starting at pos_0
|
||||
// If pos_0 == 0, the clear_kv flag will be auto set to true
|
||||
//
|
||||
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
||||
//
|
||||
LLAMA_API struct llama_batch llama_batch_get_one(
|
||||
const llama_token * tokens,
|
||||
uint32_t n_tokens,
|
||||
llama_pos pos_0,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
LLAMA_API int llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch,
|
||||
int n_threads);
|
||||
|
||||
// Token logits obtained from the last call to llama_eval()
|
||||
|
@ -434,6 +462,9 @@ extern "C" {
|
|||
// Sampling functions
|
||||
//
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||
|
||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue