remove constant seed, make alpha and beta parameters configurable from the command line

Bartosz Podkanowicz 2023-11-09 15:37:53 +01:00
parent 38c5b7ee5f
commit db2a5beef1

@@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
     gpt_params params_expert;
     gpt_params params_amateur;
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s EXPERT_MODEL_PATH AMATEUR_MODEL_PATH [PROMPT]\n" , argv[0]);
+        printf("usage: %s EXPERT_MODEL_PATH AMATEUR_MODEL_PATH [PROMPT] [alpha] [beta]\n" , argv[0]);
         return 1;
     }
@@ -28,6 +28,17 @@ int main(int argc, char ** argv) {
         params_amateur.prompt = argv[3];
     }
+    float alpha = 0.1;
+    float beta = 0.5;
+    if(argc >= 5){
+        alpha = std::stof(argv[4]);
+    }
+    if(argc >= 6){
+        beta = std::stof(argv[5]);
+    }
     if (params_expert.prompt.empty()) {
         params_expert.prompt = "Hello my name is";
         params_amateur.prompt = "Hello my name is";
@@ -64,7 +75,6 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.seed = 1234;
     ctx_params.n_ctx = 2048;
     ctx_params.n_threads = params_expert.n_threads;
     ctx_params.n_threads_batch = params_expert.n_threads_batch == -1 ? params_expert.n_threads : params_expert.n_threads_batch;
@@ -138,10 +148,6 @@ int main(int argc, char ** argv) {
     int n_decode = 0;
     const auto t_main_start = ggml_time_us();
-    float alpha = 0.1;
-    float beta = 0.5;
     while (n_cur <= n_len) {
         // sample the next token
         {
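
Note that std::stof throws std::invalid_argument on non-numeric input and std::out_of_range on overflow, so a malformed alpha or beta argument would terminate the program with an uncaught exception. A minimal defensive sketch of the same parsing (not part of this commit; parse_float_arg is a hypothetical helper, and the argv positions match the patch above):

#include <cstdio>
#include <stdexcept>
#include <string>

// Hypothetical helper: parse argv[index] as a float, falling back to
// `fallback` when the argument is missing or malformed.
static float parse_float_arg(int argc, char ** argv, int index, float fallback) {
    if (argc <= index) {
        return fallback;
    }
    try {
        return std::stof(argv[index]);
    } catch (const std::exception &) {
        fprintf(stderr, "warning: could not parse '%s' as a float, using %.2f\n",
                argv[index], fallback);
        return fallback;
    }
}

int main(int argc, char ** argv) {
    // Same positions as in the patch: argv[4] = alpha, argv[5] = beta.
    float alpha = parse_float_arg(argc, argv, 4, 0.1f);
    float beta  = parse_float_arg(argc, argv, 5, 0.5f);
    printf("alpha = %.2f, beta = %.2f\n", alpha, beta);
    return 0;
}

With the updated usage string, an invocation might then look like (binary name, model paths, and values are illustrative): ./main expert.gguf amateur.gguf "Hello my name is" 0.2 0.7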