llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch (#9745)

* refactor llama_batch_get_one

* adapt all examples

* fix simple.cpp

* fix llama_bench

* fix

* fix context shifting

* free batch before return

* use common_batch_add, reuse llama_batch in loop

* null terminated seq_id list

* fix save-load-state example

* fix perplexity

* correct token pos in llama_batch_allocr
This commit is contained in:
Xuan Son Nguyen 2024-10-18 23:18:01 +02:00 committed by GitHub
parent afd9909a64
commit cda0e4b648
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 205 additions and 118 deletions

View file

@ -232,8 +232,11 @@ extern "C" {
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// (if set to NULL, the token position will be tracked automatically by llama_decode)
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for last token will be returned)
//
typedef struct llama_batch {
int32_t n_tokens;
@ -244,15 +247,6 @@ extern "C" {
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
//
// pos[i] = all_pos_0 + i*all_pos_1
//
llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;
enum llama_model_kv_override_type {
@ -776,15 +770,15 @@ extern "C" {
// Decoding
//
// Return batch for single sequence of tokens starting at pos_0
// Return batch for single sequence of tokens
// The sequence ID will be fixed to 0
// The position of the tokens will be tracked automatically by llama_decode
//
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
//
LLAMA_API struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,
llama_pos pos_0,
llama_seq_id seq_id);
int32_t n_tokens);
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids