cont : fix save-load-state RNG seeding

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-08-12 19:44:44 +03:00
parent 6174762877
commit 48607c7a77
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -8,6 +8,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
params.prompt = "The quick brown fox"; params.prompt = "The quick brown fox";
params.sparams.seed = 1234;
if (!gpt_params_parse(argc, argv, params)) { if (!gpt_params_parse(argc, argv, params)) {
gpt_params_print_usage(argc, argv, params); gpt_params_print_usage(argc, argv, params);
@ -37,7 +38,10 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); llama_sampling_params sparams = llama_sampling_default_params();
sparams.seed = params.sparams.seed;
llama_sampling * smpl = llama_sampling_init(model, sparams);
// tokenize prompt // tokenize prompt
auto tokens = llama_tokenize(ctx, params.prompt, true); auto tokens = llama_tokenize(ctx, params.prompt, true);
@ -97,7 +101,7 @@ int main(int argc, char ** argv) {
// make new context // make new context
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl2 = llama_sampling_init(model, llama_sampling_default_params()); llama_sampling * smpl2 = llama_sampling_init(model, sparams);
printf("\nsecond run: %s", params.prompt.c_str()); printf("\nsecond run: %s", params.prompt.c_str());
@ -162,7 +166,7 @@ int main(int argc, char ** argv) {
// make new context // make new context
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl3 = llama_sampling_init(model, llama_sampling_default_params()); llama_sampling * smpl3 = llama_sampling_init(model, sparams);
printf("\nsingle seq run: %s", params.prompt.c_str()); printf("\nsingle seq run: %s", params.prompt.c_str());