sampling : one sequence per sampling context

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-12 20:35:01 +03:00
parent 370359e5ba
commit 5261aee8d8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 28 additions and 86 deletions

View file

@ -69,6 +69,8 @@ struct client {
std::string response;
std::vector<llama_token> tokens_prev;
llama_sampling_context ctx_sampling;
};
static void print_date_time() {
@ -125,8 +127,6 @@ int main(int argc, char ** argv) {
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@ -156,6 +156,7 @@ int main(int argc, char ** argv) {
client.id = i;
client.tokens_prev.resize(std::max(256, params.n_predict));
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
client.ctx_sampling = llama_sampling_context_init(params, NULL);
}
std::vector<llama_token_data> candidates;
@ -341,7 +342,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
const llama_token id = llama_sampling_sample(ctx, NULL, client.ctx_sampling, client.tokens_prev, candidates, client.i_batch - i);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@ -386,7 +387,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded;
llama_sampling_context_reset(ctx_sampling, client.seq_id);
client.seq_id = -1;
}

View file

@ -9,6 +9,12 @@
#include <string>
#include <vector>
struct seq_draft {
std::vector<llama_token> tokens;
struct llama_grammar * grammar = NULL;
};
int main(int argc, char ** argv) {
gpt_params params;
@ -213,13 +219,8 @@ int main(int argc, char ** argv) {
if (grammar_dft) {
llama_grammar_free(grammar_dft);
}
// Note: Hardcoded to sequence id 0, if this ever supports parallel generation
// that will need to change.
auto it = ctx_sampling.sequence_contexts.find(0);
GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
// This is necessary because each sequence id in sequence_contexts
// uses a copy of the original grammar.
grammar_dft = llama_grammar_copy(it->second.grammar);
grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
LOG("copied target grammar to draft grammar\n");
}