speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
This commit is contained in:
parent
c67fe68e41
commit
0e89203b51
21 changed files with 737 additions and 578 deletions
|
@ -51,6 +51,12 @@ static std::vector<std::string> k_prompts = {
|
|||
};
|
||||
|
||||
struct client {
|
||||
~client() {
|
||||
if (ctx_sampling) {
|
||||
llama_sampling_free(ctx_sampling);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t id = 0;
|
||||
|
||||
llama_seq_id seq_id = -1;
|
||||
|
@ -68,7 +74,7 @@ struct client {
|
|||
std::string prompt;
|
||||
std::string response;
|
||||
|
||||
std::vector<llama_token> tokens_prev;
|
||||
struct llama_sampling_context * ctx_sampling = nullptr;
|
||||
};
|
||||
|
||||
static void print_date_time() {
|
||||
|
@ -125,8 +131,6 @@ int main(int argc, char ** argv) {
|
|||
params.logits_all = true;
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
|
||||
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
|
||||
|
||||
// load the prompts from an external file if there are any
|
||||
if (params.prompt.empty()) {
|
||||
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
|
||||
|
@ -147,20 +151,15 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "\n\n");
|
||||
fflush(stderr);
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_vocab = llama_n_vocab(model);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
std::vector<client> clients(n_clients);
|
||||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.tokens_prev.resize(std::max(256, params.n_predict));
|
||||
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
||||
client.ctx_sampling = llama_sampling_init(params);
|
||||
}
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
tokens_system = ::llama_tokenize(ctx, k_system, true);
|
||||
const int32_t n_tokens_system = tokens_system.size();
|
||||
|
@ -169,7 +168,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0);
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
|
@ -184,13 +183,8 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
|
||||
|
||||
batch.n_tokens = n_tokens_system;
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = tokens_system[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
for (int32_t i = 0; i < n_tokens_system; ++i) {
|
||||
llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
|
@ -209,7 +203,7 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("Processing requests ...\n\n");
|
||||
|
||||
while (true) {
|
||||
batch.n_tokens = 0;
|
||||
llama_batch_clear(batch);
|
||||
|
||||
// decode any currently ongoing sequences
|
||||
for (auto & client : clients) {
|
||||
|
@ -217,15 +211,11 @@ int main(int argc, char ** argv) {
|
|||
continue;
|
||||
}
|
||||
|
||||
batch.token [batch.n_tokens] = client.sampled;
|
||||
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
client.n_decoded += 1;
|
||||
client.i_batch = batch.n_tokens;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
|
||||
|
||||
client.n_decoded += 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
|
@ -250,18 +240,14 @@ int main(int argc, char ** argv) {
|
|||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
||||
llama_sampling_reset(client.ctx_sampling);
|
||||
|
||||
// do not prepend BOS because we have a system prompt!
|
||||
std::vector<llama_token> tokens_prompt;
|
||||
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
batch.token [batch.n_tokens] = tokens_prompt[i];
|
||||
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = false;
|
||||
batch.n_tokens += 1;
|
||||
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
|
@ -304,11 +290,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
|
@ -341,7 +328,9 @@ int main(int argc, char ** argv) {
|
|||
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||
|
||||
const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
|
||||
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||
|
||||
llama_sampling_accept(client.ctx_sampling, ctx, id);
|
||||
|
||||
if (client.n_decoded == 1) {
|
||||
// start measuring generation time after the first token to make sure all concurrent clients
|
||||
|
@ -349,11 +338,8 @@ int main(int argc, char ** argv) {
|
|||
client.t_start_gen = ggml_time_us();
|
||||
}
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
client.tokens_prev.erase(client.tokens_prev.begin());
|
||||
client.tokens_prev.push_back(id);
|
||||
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
|
||||
client.response += token_str;
|
||||
client.sampled = id;
|
||||
|
||||
|
@ -386,7 +372,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
n_total_prompt += client.n_prompt;
|
||||
n_total_gen += client.n_decoded;
|
||||
llama_sampling_context_reset(ctx_sampling, client.seq_id);
|
||||
|
||||
client.seq_id = -1;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue