llama : improve llama_batch API + simplify parallel example

Georgi Gerganov 2023-09-20 10:46:18 +03:00
parent a1327c71c6
commit addae65fd4
6 changed files with 111 additions and 70 deletions

llama.h (32 changed lines)

@@ -70,11 +70,11 @@ extern "C" {
     typedef struct llama_batch {
         uint32_t n_tokens;
 
-        const llama_token  * token;
-        const float        * embd;
-        const llama_pos    * pos;
-        const llama_seq_id * seq_id;
-        const int8_t       * logits; // if 0, do not extract logits for that token
+        llama_token  * token;
+        float        * embd;
+        llama_pos    * pos;
+        llama_seq_id * seq_id;
+        int8_t       * logits; // if 0, do not extract logits for that token
 
         // NOTE: helpers for smooth API transition - can be deprecated in the future
         //       for future-proof code, use the above fields instead and ignore everything below
@@ -84,7 +84,7 @@ extern "C" {
         llama_pos    all_pos_0;  // used if pos == NULL
         llama_pos    all_pos_1;  // used if pos == NULL
         llama_seq_id all_seq_id; // used if seq_id == NULL
-    } llama_seq;
+    } llama_batch;
 
     enum llama_log_level {
         LLAMA_LOG_LEVEL_ERROR = 2,
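With the const qualifiers dropped, calling code can point a llama_batch directly at buffers it owns and keep mutating them between submissions. A minimal sketch of building such a batch; the helper name and the fixed token count below are illustrative, not part of llama.h:

// Illustrative helper (not part of llama.h): build a batch over caller-owned arrays.
// Assumes a fixed number of tokens; the transition helper fields stay zeroed.
#include "llama.h"

#define N_TOK 4

struct llama_batch batch_from_arrays(
        llama_token  * tok,      // N_TOK token ids
        llama_pos    * pos,      // N_TOK positions
        llama_seq_id * seq_id,   // N_TOK sequence ids
        int8_t       * logits) { // N_TOK logit flags
    struct llama_batch batch = {0};

    batch.n_tokens = N_TOK;
    batch.token    = tok;    // legal now that the field is no longer `const llama_token *`
    batch.embd     = NULL;   // token ids are used, not embeddings
    batch.pos      = pos;
    batch.seq_id   = seq_id;
    batch.logits   = logits; // set logits[i] = 1 only for tokens whose logits are needed

    return batch;
}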
@@ -366,34 +366,46 @@ extern "C" {
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
     // Returns 0 on success
+    // DEPRECATED: use llama_decode() instead
     LLAMA_API DEPRECATED(int llama_eval(
             struct llama_context * ctx,
-            const llama_token * tokens,
+            llama_token * tokens,
             uint32_t n_tokens,
             int n_past,
             int n_threads),
             "please use llama_decode() instead");
 
     // Same as llama_eval, but use float matrix input directly.
+    // DEPRECATED: use llama_decode() instead
     LLAMA_API DEPRECATED(int llama_eval_embd(
             struct llama_context * ctx,
-            const float * embd,
+            float * embd,
             uint32_t n_tokens,
             int n_past,
             int n_threads),
             "please use llama_decode() instead");
 
     // Return batch for single sequence of tokens starting at pos_0
     // If pos_0 == 0, the clear_kv flag will be auto set to true
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
     LLAMA_API struct llama_batch llama_batch_get_one(
-            const llama_token * tokens,
+            llama_token * tokens,
             uint32_t n_tokens,
             llama_pos pos_0,
             llama_seq_id seq_id);
 
+    // Allocates a batch of tokens on the heap
+    // The batch needs to be freed with llama_batch_free()
+    // If embd > 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
+    // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
+    // The rest of the llama_batch members are allocated with size n_tokens
+    // All members are left uninitialized
+    LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd);
+
+    // Frees a batch of tokens allocated with llama_batch_init()
+    LLAMA_API void llama_batch_free(struct llama_batch batch);
+
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
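The new llama_batch_init()/llama_batch_free() pair together with the return codes above suggest the following lifecycle. A hedged sketch: the llama_decode() signature is assumed here to be (ctx, batch); at this point of the transition the branch may still pass an n_threads argument as llama_eval does.

// Sketch of the allocate / fill / decode / free lifecycle added in this commit.
// Assumption: llama_decode(ctx, batch) - add n_threads if this branch still requires it.
#include "llama.h"
#include <stdio.h>

static int decode_prompt(struct llama_context * ctx, const llama_token * prompt, uint32_t n_prompt) {
    // token batch (embd == 0): token/pos/seq_id/logits are allocated for n_prompt entries
    // but left uninitialized, so every field must be filled before decoding
    struct llama_batch batch = llama_batch_init(n_prompt, 0);

    batch.n_tokens = n_prompt;
    for (uint32_t i = 0; i < n_prompt; ++i) {
        batch.token [i] = prompt[i];
        batch.pos   [i] = (llama_pos) i;
        batch.seq_id[i] = 0;                    // single sequence
        batch.logits[i] = (i == n_prompt - 1);  // extract logits only for the last token
    }

    const int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        fprintf(stderr, "no KV slot for the batch - reduce the batch size or increase the context\n");
    }

    llama_batch_free(batch);
    return ret;
}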