speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context

ggml-ci

* speculative : add tree-based sampling support

ggml-ci

* speculative : reuse the n_parallel CLI param

* speculative : refactor sampling

* examples : fix build after sampling refactoring

ggml-ci

* batched : fix n_seq_id

* sampling : fix malloc

ggml-ci

* swift : fix build

ggml-ci

* swift : try to fix build

ggml-ci

* prompts : add assistant.txt

* common : add llama_batch_add() and llama_batch_clear() helpers

* speculative : minor refactor

ggml-ci

* minor : comments + rename

ggml-ci

* speculative : fix off-by-one for n_drafted

* speculative : fix the n_drafted fix + p constants
This commit is contained in:
Georgi Gerganov 2023-10-18 16:21:57 +03:00 committed by GitHub
parent c67fe68e41
commit 0e89203b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 737 additions and 578 deletions

View file

@ -51,6 +51,12 @@ static std::vector<std::string> k_prompts = {
};
struct client {
~client() {
if (ctx_sampling) {
llama_sampling_free(ctx_sampling);
}
}
int32_t id = 0;
llama_seq_id seq_id = -1;
@ -68,7 +74,7 @@ struct client {
std::string prompt;
std::string response;
std::vector<llama_token> tokens_prev;
struct llama_sampling_context * ctx_sampling = nullptr;
};
static void print_date_time() {
@ -125,8 +131,6 @@ int main(int argc, char ** argv) {
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@ -147,20 +151,15 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n");
fflush(stderr);
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(model);
const int n_ctx = llama_n_ctx(ctx);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.tokens_prev.resize(std::max(256, params.n_predict));
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
client.ctx_sampling = llama_sampling_init(params);
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
std::vector<llama_token> tokens_system;
tokens_system = ::llama_tokenize(ctx, k_system, true);
const int32_t n_tokens_system = tokens_system.size();
@ -169,7 +168,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0);
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@ -184,13 +183,8 @@ int main(int argc, char ** argv) {
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
batch.n_tokens = n_tokens_system;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
for (int32_t i = 0; i < n_tokens_system; ++i) {
llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
}
if (llama_decode(ctx, batch) != 0) {
@ -209,7 +203,7 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");
while (true) {
batch.n_tokens = 0;
llama_batch_clear(batch);
// decode any currently ongoing sequences
for (auto & client : clients) {
@ -217,15 +211,11 @@ int main(int argc, char ** argv) {
continue;
}
batch.token [batch.n_tokens] = client.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = true;
client.n_decoded += 1;
client.i_batch = batch.n_tokens;
batch.n_tokens += 1;
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
@ -250,18 +240,14 @@ int main(int argc, char ** argv) {
client.prompt = client.input + "\nAssistant:";
client.response = "";
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
llama_sampling_reset(client.ctx_sampling);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch.token [batch.n_tokens] = tokens_prompt[i];
batch.pos [batch.n_tokens] = i + n_tokens_system;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = false;
batch.n_tokens += 1;
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
}
// extract the logits only for the last token
@ -304,11 +290,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
batch.token + i,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
@ -341,7 +328,9 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
llama_sampling_accept(client.ctx_sampling, ctx, id);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@ -349,11 +338,8 @@ int main(int argc, char ** argv) {
client.t_start_gen = ggml_time_us();
}
// remember which tokens were sampled - used for repetition penalties during sampling
client.tokens_prev.erase(client.tokens_prev.begin());
client.tokens_prev.push_back(id);
const std::string token_str = llama_token_to_piece(ctx, id);
client.response += token_str;
client.sampled = id;
@ -386,7 +372,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded;
llama_sampling_context_reset(ctx_sampling, client.seq_id);
client.seq_id = -1;
}