simple : add parallel decoding support

This commit is contained in:
Georgi Gerganov 2023-09-20 13:06:34 +03:00
parent addae65fd4
commit b377bf2266
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 187 additions and 76 deletions

15
llama.h
View file

@ -68,7 +68,7 @@ extern "C" {
// data used for batch inference
typedef struct llama_batch {
uint32_t n_tokens;
int32_t n_tokens;
llama_token * token;
float * embd;
@ -370,7 +370,7 @@ extern "C" {
LLAMA_API DEPRECATED(int llama_eval(
struct llama_context * ctx,
llama_token * tokens,
uint32_t n_tokens,
int32_t n_tokens,
int n_past,
int n_threads),
"please use llama_decode() instead");
@ -380,7 +380,7 @@ extern "C" {
LLAMA_API DEPRECATED(int llama_eval_embd(
struct llama_context * ctx,
float * embd,
uint32_t n_tokens,
int32_t n_tokens,
int n_past,
int n_threads),
"please use llama_decode() instead");
@ -391,7 +391,7 @@ extern "C" {
//
LLAMA_API struct llama_batch llama_batch_get_one(
llama_token * tokens,
uint32_t n_tokens,
int32_t n_tokens,
llama_pos pos_0,
llama_seq_id seq_id);
@ -401,7 +401,7 @@ extern "C" {
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
// The rest of the llama_batch members are allocated with size n_tokens
// All members are left uninitialized
LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd);
LLAMA_API struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd);
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
@ -531,7 +531,10 @@ extern "C" {
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
LLAMA_API void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
LLAMA_API DEPRECATED(void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp),
"Use llama_sample_temp instead");
/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);