diff --git a/examples/server/tests/features/passkey.feature b/examples/server/tests/features/passkey.feature
index f64236766..72f8bce0b 100644
--- a/examples/server/tests/features/passkey.feature
+++ b/examples/server/tests/features/passkey.feature
@@ -19,7 +19,9 @@ Feature: Passkey / Self-extend with context shift
     And a self-extend context with a factor of <n_grp>
     And <seed> as seed
     And a KV cache size based on the model trained context <n_ctx_train> extended by <n_grp> with additional <n_keep> tokens
-    And 1 slots
+    And <n_slots> slots
+    And <n_ga> group attention factor to extend context size through self-extend
+    And <n_ga_w> group attention width to extend context size through self-extend
     # Can be override with N_GPU_LAYERS
     And <ngl> GPU offloaded layers
     Then the server is starting
@@ -47,5 +49,5 @@ Feature: Passkey / Self-extend with context shift
     Then <n_predicted> tokens are predicted matching <re_content>

     Examples:
-      | hf_repo             | hf_file           | n_ctx_train | ngl | n_batch | n_junk | n_grp | i_pos | seed | n_keep | passkey | n_predicted | re_content |
-      | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048        | 5   | 512     | 250    | 4     | 50    | 86   | 32     | 42      | 4           | .*42.*     |
+      | hf_repo             | hf_file           | n_ctx_train | ngl | n_batch | n_slots | n_ga | n_ga_w | n_junk | n_grp | i_pos | seed | n_keep | passkey | n_predicted | re_content |
+      | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048        | 5   | 512     | 1       | 4    | 2048   | 250    | 4     | 50    | 86   | 32     | 42      | -1          | .*42.*     |
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 8d4ad4f2b..925ea69ef 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -30,6 +30,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.model_alias = None
     context.n_batch = None
     context.n_ctx = None
+    context.n_ga = None
+    context.n_ga_w = None
     context.n_gpu_layer = None
     context.n_predict = None
     context.n_server_predict = None
@@ -191,14 +193,14 @@ async def step_request_completion(context, api_error):
         assert completion == 401, f"completion must be an 401 status code: {completion}"


-@step(u'{predicted_n} tokens are predicted matching {re_content}')
+@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
+    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)


-@step(u'{predicted_n} tokens are predicted')
+@step(u'{predicted_n:d} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
+    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)


 @step(u'a user prompt {user_prompt}')
@@ -280,6 +282,16 @@ def step_prompt_junk_suffix(context):
 def step_prompt_suffix(context):
     context.prompt_suffix = context.text
+
+@step(u'{n_ga:d} group attention factor'
+      u' to extend context size through self-extend')
+def step_impl(context, n_ga):
+    context.n_ga = n_ga
+
+
+@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
+def step_impl(context, n_ga_w):
+    context.n_ga_w = n_ga_w


 @step(u'a passkey prompt template')
 def step_prompt_passkey_template(context):
@@ -804,7 +816,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
     assert len(content) > 0, "no token predicted"
-    if expected_predicted_n is not None:
+    if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
     if re_content is not None:
@@ -926,16 +938,20 @@ def start_server_background(context):
         server_args.append('--embedding')
     if context.server_metrics:
         server_args.append('--metrics')
-    if context.model_alias is not None:
+    if context.model_alias:
         server_args.extend(['--alias', context.model_alias])
-    if context.n_ctx is not None:
+    if context.n_ctx:
         server_args.extend(['--ctx-size', context.n_ctx])
-    if context.n_slots is not None:
+    if context.n_slots:
         server_args.extend(['--parallel', context.n_slots])
-    if context.n_server_predict is not None:
+    if context.n_server_predict:
         server_args.extend(['--n-predict', context.n_server_predict])
-    if context.server_api_key is not None:
+    if context.server_api_key:
         server_args.extend(['--api-key', context.server_api_key])
+    if context.n_ga:
+        server_args.extend(['--grp-attn-n', context.n_ga])
+    if context.n_ga_w:
+        server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
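
A note on the `{predicted_n:d}` changes above: behave's default step matcher is the `parse` library, whose `:d` type conversion hands the step implementation an `int` directly, which is why the explicit `int()` casts could be dropped. A minimal standalone sketch of that conversion (using `parse` directly, outside behave):

```python
import parse  # the library behind behave's default "parse" step matcher

# ':d' converts the captured text to int before the step function sees it
result = parse.parse(u'{predicted_n:d} tokens are predicted', u'32 tokens are predicted')
assert result['predicted_n'] == 32
assert isinstance(result['predicted_n'], int)
```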
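The `n_predicted` column flipping from `4` to `-1` works together with the relaxed guard in `assert_n_tokens_predicted`: a non-positive expectation now skips the exact-count assertion, so the passkey scenario only enforces the `re_content` regex. A small sketch of how the new condition reads, under that interpretation:

```python
def count_is_enforced(expected_predicted_n):
    # Mirrors `if expected_predicted_n and expected_predicted_n > 0`:
    # None, 0 and sentinels such as -1 all skip the exact-count check.
    return bool(expected_predicted_n) and expected_predicted_n > 0

assert count_is_enforced(32)          # positive expectation: checked
assert not count_is_enforced(-1)      # sentinel from the Examples table: skipped
assert not count_is_enforced(None)    # no expectation was set: skipped
```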
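One behavioural nuance of switching the `start_server_background` checks from `is not None` to plain truthiness: a value that is set but falsy (for example an explicit `0` or an empty string) is now silently not forwarded to the server. That is presumably harmless for these options, but worth keeping in mind; an illustration with a hypothetical zero value:

```python
n_ctx = 0  # hypothetical: a test that explicitly configured zero
server_args = []
if n_ctx:  # falsy, so the flag is skipped; `n_ctx is not None` would have kept it
    server_args.extend(['--ctx-size', n_ctx])
assert server_args == []  # '--ctx-size 0' is never passed
```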