sampling: separate rng per sampling context
This commit is contained in:
parent
b1a189115e
commit
123eaf054f
8 changed files with 34 additions and 10 deletions
|
@ -1,4 +1,6 @@
|
||||||
|
#define LLAMA_API_INTERNAL
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
|
#include <random>
|
||||||
|
|
||||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||||
struct llama_sampling_context * result = new llama_sampling_context();
|
struct llama_sampling_context * result = new llama_sampling_context();
|
||||||
|
@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
result->prev.resize(params.n_prev);
|
||||||
|
|
||||||
|
llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||||
ctx->cur.clear();
|
ctx->cur.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
|
||||||
|
if (seed == LLAMA_DEFAULT_SEED) {
|
||||||
|
seed = time(NULL);
|
||||||
|
}
|
||||||
|
ctx->rng.seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
|
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
|
||||||
if (dst->grammar) {
|
if (dst->grammar) {
|
||||||
llama_grammar_free(dst->grammar);
|
llama_grammar_free(dst->grammar);
|
||||||
|
@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
|
||||||
|
|
||||||
sampler_queue(ctx_main, params, cur_p, min_keep);
|
sampler_queue(ctx_main, params, cur_p, min_keep);
|
||||||
|
|
||||||
id = llama_sample_token(ctx_main, &cur_p);
|
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
|
||||||
|
|
||||||
//{
|
//{
|
||||||
// const int n_top = 10;
|
// const int n_top = 10;
|
||||||
|
|
|
@ -4,9 +4,10 @@
|
||||||
|
|
||||||
#include "grammar-parser.h"
|
#include "grammar-parser.h"
|
||||||
|
|
||||||
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
// sampler types
|
// sampler types
|
||||||
enum class llama_sampler_type : char {
|
enum class llama_sampler_type : char {
|
||||||
|
@ -79,6 +80,8 @@ struct llama_sampling_context {
|
||||||
// TODO: replace with ring-buffer
|
// TODO: replace with ring-buffer
|
||||||
std::vector<llama_token> prev;
|
std::vector<llama_token> prev;
|
||||||
std::vector<llama_token_data> cur;
|
std::vector<llama_token_data> cur;
|
||||||
|
|
||||||
|
std::mt19937 rng;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
@ -93,6 +96,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||||
// - reset grammar
|
// - reset grammar
|
||||||
void llama_sampling_reset(llama_sampling_context * ctx);
|
void llama_sampling_reset(llama_sampling_context * ctx);
|
||||||
|
|
||||||
|
// Set the sampler seed
|
||||||
|
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
|
||||||
|
|
||||||
// Copy the sampler context
|
// Copy the sampler context
|
||||||
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
|
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,6 @@ int main(int argc, char ** argv){
|
||||||
|
|
||||||
// load the model
|
// load the model
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||||
llama_set_rng_seed(ctx, params.seed);
|
|
||||||
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
|
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
|
|
|
@ -38,7 +38,6 @@ int main(int argc, char ** argv){
|
||||||
|
|
||||||
// load the model
|
// load the model
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||||
llama_set_rng_seed(ctx, params.seed);
|
|
||||||
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
|
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
|
@ -108,6 +107,7 @@ int main(int argc, char ** argv){
|
||||||
bool has_eos = false;
|
bool has_eos = false;
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
|
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
|
||||||
|
|
||||||
std::vector<llama_token> draft;
|
std::vector<llama_token> draft;
|
||||||
|
|
||||||
|
|
|
@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
session_tokens.resize(n_token_count_out);
|
session_tokens.resize(n_token_count_out);
|
||||||
llama_set_rng_seed(ctx, params.seed);
|
|
||||||
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
|
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -521,6 +520,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||||
|
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
|
||||||
|
|
||||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
|
|
|
@ -1028,7 +1028,7 @@ struct server_context {
|
||||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
llama_set_rng_seed(ctx, slot.params.seed);
|
llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.command = SLOT_COMMAND_LOAD_PROMPT;
|
slot.command = SLOT_COMMAND_LOAD_PROMPT;
|
||||||
|
|
|
@ -13478,7 +13478,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
|
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||||
GGML_ASSERT(ctx);
|
GGML_ASSERT(ctx);
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
@ -13491,7 +13491,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
|
||||||
}
|
}
|
||||||
|
|
||||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||||
auto & rng = ctx->rng;
|
|
||||||
int idx = dist(rng);
|
int idx = dist(rng);
|
||||||
|
|
||||||
llama_token result = candidates->data[idx].id;
|
llama_token result = candidates->data[idx].id;
|
||||||
|
@ -13501,6 +13500,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||||
|
return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
|
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
|
9
llama.h
9
llama.h
|
@ -987,7 +987,7 @@ extern "C" {
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_token_data_array * candidates);
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
/// @details Randomly selects a token from the candidates based on their probabilities.
|
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
|
||||||
LLAMA_API llama_token llama_sample_token(
|
LLAMA_API llama_token llama_sample_token(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_token_data_array * candidates);
|
llama_token_data_array * candidates);
|
||||||
|
@ -1074,8 +1074,9 @@ extern "C" {
|
||||||
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
||||||
#ifdef LLAMA_API_INTERNAL
|
#ifdef LLAMA_API_INTERNAL
|
||||||
|
|
||||||
#include <vector>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
struct ggml_tensor;
|
struct ggml_tensor;
|
||||||
|
|
||||||
|
@ -1112,6 +1113,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||||
const std::string & src,
|
const std::string & src,
|
||||||
llama_partial_utf8 partial_start);
|
llama_partial_utf8 partial_start);
|
||||||
|
|
||||||
|
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
|
||||||
|
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
|
||||||
|
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
|
||||||
|
|
||||||
#endif // LLAMA_API_INTERNAL
|
#endif // LLAMA_API_INTERNAL
|
||||||
|
|
||||||
#endif // LLAMA_H
|
#endif // LLAMA_H
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue