speculative : add tree-based sampling support

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-14 17:54:02 +03:00
parent 5261aee8d8
commit 4de5a2d473
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 469 additions and 192 deletions

17
llama.h
View file

@ -133,11 +133,12 @@ extern "C" {
typedef struct llama_batch {
int32_t n_tokens;
llama_token * token;
float * embd;
llama_pos * pos;
llama_seq_id * seq_id;
int8_t * logits;
llama_token * token;
float * embd;
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits;
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
@ -446,7 +447,8 @@ extern "C" {
llama_pos pos_0,
llama_seq_id seq_id);
// Allocates a batch of tokens on the heap
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids
// The batch has to be freed with llama_batch_free()
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
@ -454,7 +456,8 @@ extern "C" {
// All members are left uninitialized
LLAMA_API struct llama_batch llama_batch_init(
int32_t n_tokens,
int32_t embd);
int32_t embd,
int32_t n_seq_max);
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);