llama : improve llama_batch API + simplify parallel example
parent a1327c71c6
commit addae65fd4
6 changed files with 111 additions and 70 deletions
llama.h | 32
@@ -70,11 +70,11 @@ extern "C" {
     typedef struct llama_batch {
         uint32_t n_tokens;

-        const llama_token  * token;
-        const float        * embd;
-        const llama_pos    * pos;
-        const llama_seq_id * seq_id;
-        const int8_t       * logits; // if 0, do not extract logits for that token
+        llama_token  * token;
+        float        * embd;
+        llama_pos    * pos;
+        llama_seq_id * seq_id;
+        int8_t       * logits; // if 0, do not extract logits for that token

         // NOTE: helpers for smooth API transition - can be deprecated in the future
         //       for future-proof code, use the above fields instead and ignore everything below
@@ -84,7 +84,7 @@ extern "C" {
         llama_pos    all_pos_0;  // used if pos == NULL
         llama_pos    all_pos_1;  // used if pos == NULL
         llama_seq_id all_seq_id; // used if seq_id == NULL
-    } llama_seq;
+    } llama_batch;

     enum llama_log_level {
         LLAMA_LOG_LEVEL_ERROR = 2,
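
For context, a minimal sketch of how caller code can fill a batch in place now that the const qualifiers are gone. The helper name fill_prompt and the single-sequence layout are illustrative assumptions, not part of this commit; only the fields declared in the hunks above are relied on:

    // Hypothetical helper: copy a prompt into an already-allocated batch.
    // Assumes batch->token/pos/seq_id/logits each hold at least n_tokens
    // elements (e.g. allocated with llama_batch_init(), added further below).
    static void fill_prompt(struct llama_batch * batch,
                            const llama_token  * prompt,
                            uint32_t             n_tokens,
                            llama_seq_id         seq) {
        batch->n_tokens = n_tokens;
        for (uint32_t i = 0; i < n_tokens; ++i) {
            batch->token [i] = prompt[i];
            batch->pos   [i] = (llama_pos) i;       // positions start at 0
            batch->seq_id[i] = seq;                 // one sequence for all tokens
            batch->logits[i] = (i == n_tokens - 1); // logits only for the last token
        }
    }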
@@ -366,34 +366,46 @@ extern "C" {
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
     // Returns 0 on success
     // DEPRECATED: use llama_decode() instead
     LLAMA_API DEPRECATED(int llama_eval(
             struct llama_context * ctx,
-            const llama_token * tokens,
+                  llama_token * tokens,
                      uint32_t n_tokens,
                           int n_past,
                           int n_threads),
             "please use llama_decode() instead");

     // Same as llama_eval, but use float matrix input directly.
     // DEPRECATED: use llama_decode() instead
     LLAMA_API DEPRECATED(int llama_eval_embd(
             struct llama_context * ctx,
-            const float * embd,
+                  float * embd,
                  uint32_t n_tokens,
                       int n_past,
                       int n_threads),
             "please use llama_decode() instead");

     // Return batch for single sequence of tokens starting at pos_0
     // If pos_0 == 0, the clear_kv flag will be auto set to true
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
     LLAMA_API struct llama_batch llama_batch_get_one(
-            const llama_token * tokens,
+                  llama_token * tokens,
                  uint32_t n_tokens,
                 llama_pos pos_0,
              llama_seq_id seq_id);

+    // Allocates a batch of tokens on the heap
+    // The batch needs to be freed with llama_batch_free()
+    // If embd > 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
+    // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
+    // The rest of the llama_batch members are allocated with size n_tokens
+    // All members are left uninitialized
+    LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd);
+
+    // Frees a batch of tokens allocated with llama_batch_init()
+    LLAMA_API void llama_batch_free(struct llama_batch batch);
+
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
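
Putting the new helpers together, a hedged sketch of the intended allocate/decode/free cycle. The ctx variable and model setup are assumed, and llama_decode() is shown with a (ctx, batch) signature; at this stage of the transition it may still take an n_threads argument, like the deprecated llama_eval() above:

    // Allocate a token batch (embd == 0) that can hold up to 512 tokens.
    struct llama_batch batch = llama_batch_init(512, 0);

    // ... set batch.n_tokens and fill batch.token/pos/seq_id/logits ...
    // Transitional alternative (avoid in new code):
    //   struct llama_batch one = llama_batch_get_one(tokens, n_tokens, 0, 0);

    const int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // warning: no KV slot found for the batch - retry with a smaller
        // batch or a larger context
    } else if (ret != 0) {
        // error
    }

    llama_batch_free(batch);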
|
Loading…
Add table
Add a link
Reference in a new issue