llama : unified KV cache + batch inference API
This commit is contained in:
parent
fad56936d4
commit
d29e76937c
10 changed files with 315 additions and 236 deletions
34
llama.h
34
llama.h
|
@ -60,7 +60,20 @@ extern "C" {
|
|||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
typedef int llama_token;
|
||||
typedef int32_t llama_pos;
|
||||
typedef int32_t llama_token;
|
||||
typedef int32_t llama_seq_id;
|
||||
|
||||
// data used for batch inference
|
||||
typedef struct llama_batch {
|
||||
uint32_t n_tokens;
|
||||
|
||||
// TODO: not sure about these consts - might just get in the way all the time with no benefit
|
||||
const llama_token * token;
|
||||
const float * embd;
|
||||
const llama_pos * pos;
|
||||
const llama_seq_id * seq_id;
|
||||
} llama_seq;
|
||||
|
||||
enum llama_log_level {
|
||||
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||
|
@ -289,8 +302,15 @@ extern "C" {
|
|||
const char * path_base_model,
|
||||
int n_threads);
|
||||
|
||||
//
|
||||
// KV cache API
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache
|
||||
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||
"avoid using this, it will be removed in the future");
|
||||
|
||||
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||
|
@ -319,7 +339,7 @@ extern "C" {
|
|||
LLAMA_API int llama_eval(
|
||||
struct llama_context * ctx,
|
||||
const llama_token * tokens,
|
||||
int n_tokens,
|
||||
uint32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
|
@ -327,16 +347,10 @@ extern "C" {
|
|||
LLAMA_API int llama_eval_embd(
|
||||
struct llama_context * ctx,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
uint32_t n_tokens,
|
||||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
// Export a static computation graph for context of 511 and batch size of 1
|
||||
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
|
||||
// parameters here to keep things simple
|
||||
// IMPORTANT: do not use for anything else other than debugging and testing!
|
||||
LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
|
||||
|
||||
// Token logits obtained from the last call to llama_eval()
|
||||
// The logits for the last token are stored in the last row
|
||||
// Can be mutated in order to change the probabilities of the next token
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue