tests: allow artificial slowdown of sampling for tests

This commit is contained in:
ochafik 2024-09-29 01:09:41 +01:00
parent 1da67a395c
commit 4dcb3ea943
4 changed files with 18 additions and 0 deletions

View file

@ -1879,6 +1879,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.slot_prompt_similarity = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"--testing-sampler-delay-millis"}, "N",
format("for tests: delay in milliseconds to add to each sampling (default: %d)", params.testing_sampler_delay_millis),
[](gpt_params & params, int value) {
params.testing_sampler_delay_millis = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"--lora-init-without-apply"},
format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"),

View file

@ -299,6 +299,8 @@ struct gpt_params {
float slot_prompt_similarity = 0.5f;
int testing_sampler_delay_millis = 0;
// batched-bench params
bool is_pp_shared = false;

View file

@ -2348,6 +2348,9 @@ struct server_context {
}
completion_token_output result;
if (params.testing_sampler_delay_millis > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis));
}
const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
gpt_sampler_accept(slot.smpl, id, true);

View file

@ -78,6 +78,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.response_format = None
context.temperature = None
context.lora_file = None
context.testing_sampler_delay_millis = None
context.disable_ctx_shift = False
context.tasks_result = []
@ -455,6 +456,9 @@ def step_impl(context, n_ga):
def step_impl(context, n_ga_w):
context.n_ga_w = n_ga_w
@step('{testing_sampler_delay_millis:d} milliseconds delay in sampler for testing')
def step_testing_sampler_delay_millis(context, testing_sampler_delay_millis):
context.testing_sampler_delay_millis = testing_sampler_delay_millis
@step('a passkey prompt template')
def step_prompt_passkey(context):
@ -1436,6 +1440,8 @@ def start_server_background(context):
server_args.append('--verbose')
if context.lora_file:
server_args.extend(['--lora', context.lora_file])
if context.testing_sampler_delay_millis:
server_args.extend(['--testing-sampler-delay-millis', context.testing_sampler_delay_millis])
if context.disable_ctx_shift:
server_args.extend(['--no-context-shift'])