speculative : add tree-based sampling support

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-14 17:54:02 +03:00
parent 5261aee8d8
commit 4de5a2d473
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 469 additions and 192 deletions

View file

@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0);
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@ -188,10 +188,11 @@ int main(int argc, char ** argv) {
batch.n_tokens = n_tokens_system;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = false;
}
if (llama_decode(ctx, batch) != 0) {
@ -218,10 +219,11 @@ int main(int argc, char ** argv) {
continue;
}
batch.token [batch.n_tokens] = client.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = true;
batch.token [batch.n_tokens] = client.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
batch.n_seq_id[batch.n_tokens] = 1;
batch.seq_id [batch.n_tokens][0] = client.id;
batch.logits [batch.n_tokens] = true;
client.n_decoded += 1;
client.i_batch = batch.n_tokens;
@ -258,10 +260,11 @@ int main(int argc, char ** argv) {
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch.token [batch.n_tokens] = tokens_prompt[i];
batch.pos [batch.n_tokens] = i + n_tokens_system;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = false;
batch.token [batch.n_tokens] = tokens_prompt[i];
batch.pos [batch.n_tokens] = i + n_tokens_system;
batch.n_seq_id[batch.n_tokens] = client.id;
batch.seq_id [batch.n_tokens][0] = client.id;
batch.logits [batch.n_tokens] = false;
batch.n_tokens += 1;
}
@ -305,11 +308,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
batch.token + i,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};