Server: fix seed for multiple slots (#6835)

* Server: add tests for consistent results

* sampling: separate rng per sampling context
This commit is contained in:
Johannes Gäßler 2024-04-24 11:08:36 +02:00 committed by GitHub
parent c0d1b3e03e
commit 28103f4832
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 145 additions and 30 deletions

View file

@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
return result;
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();
@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
auto & rng = ctx->rng;
int idx = dist(rng);
llama_token result = candidates->data[idx].id;
@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
return result;
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
}
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();