server: tests: add group attention params

Pierrick HYMBERT 2024-03-02 13:50:28 +01:00
parent ab5b06b2cf
commit 60113da241
2 changed files with 31 additions and 13 deletions

--- a/examples/server/tests/features/passkey.feature
+++ b/examples/server/tests/features/passkey.feature
@@ -19,7 +19,9 @@ Feature: Passkey / Self-extend with context shift
     And a self-extend context with a factor of <n_grp>
     And <seed> as seed
     And a KV cache size based on the model trained context <n_ctx_train> extended by <n_grp> with additional <n_keep> tokens
-    And 1 slots
+    And <n_slots> slots
+    And <n_ga> group attention factor to extend context size through self-extend
+    And <n_ga_w> group attention width to extend context size through self-extend
     # Can be override with N_GPU_LAYERS
     And <ngl> GPU offloaded layers
     Then the server is starting
@@ -47,5 +49,5 @@ Feature: Passkey / Self-extend with context shift
     Then <n_predicted> tokens are predicted matching <re_content>

     Examples:
-      | hf_repo             | hf_file           | n_ctx_train | ngl | n_batch | n_junk | n_grp | i_pos | seed | n_keep | passkey | n_predicted | re_content |
-      | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048        | 5   | 512     | 250    | 4     | 50    | 86   | 32     | 42      | 4           | .*42.*     |
+      | hf_repo             | hf_file           | n_ctx_train | ngl | n_batch | n_slots | n_ga | n_ga_w | n_junk | n_grp | i_pos | seed | n_keep | passkey | n_predicted | re_content |
+      | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048        | 5   | 512     | 1       | 4    | 2048   | 250    | 4     | 50    | 86   | 32     | 42      | -1          | .*42.*     |
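
The three new columns (n_slots, n_ga, n_ga_w) thread the group-attention parameters through the scenario, and n_predicted changes from 4 to -1, the sentinel the reworked assert_n_tokens_predicted below treats as "skip the exact count check". As a worked example of the KV-cache sizing step for this row, assuming it multiplies the trained context by the group factor and adds the kept tokens:

    # Sketch of the assumed KV-cache arithmetic for the phi-2 row;
    # not code from this commit.
    n_ctx_train, n_grp, n_keep = 2048, 4, 32
    n_ctx = n_ctx_train * n_grp + n_keep  # 2048 * 4 + 32 = 8224 tokens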

--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -30,6 +30,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.model_alias = None
     context.n_batch = None
     context.n_ctx = None
+    context.n_ga = None
+    context.n_ga_w = None
     context.n_gpu_layer = None
     context.n_predict = None
     context.n_server_predict = None
@@ -191,14 +193,14 @@ async def step_request_completion(context, api_error):
     assert completion == 401, f"completion must be an 401 status code: {completion}"


-@step(u'{predicted_n} tokens are predicted matching {re_content}')
+@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
+    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)


-@step(u'{predicted_n} tokens are predicted')
+@step(u'{predicted_n:d} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
+    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)


 @step(u'a user prompt {user_prompt}')
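
The :d suffix in the step patterns asks behave's parse-style matcher to convert the capture to int before the step function is called, which is why the explicit int(...) casts can be dropped. A minimal standalone sketch of the same pattern, with a hypothetical step text for illustration:

    from behave import step

    @step(u'{count:d} widgets are produced')  # hypothetical step text
    def step_widget_count(context, count):
        # parse's ':d' converter already turned the capture into an int
        assert isinstance(count, int)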
@@ -280,6 +282,16 @@ def step_prompt_junk_suffix(context):
 def step_prompt_suffix(context):
     context.prompt_suffix = context.text

+
+@step(u'{n_ga:d} group attention factor'
+      u' to extend context size through self-extend')
+def step_impl(context, n_ga):
+    context.n_ga = n_ga
+
+
+@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
+def step_impl(context, n_ga_w):
+    context.n_ga_w = n_ga_w

 @step(u'a passkey prompt template')
 def step_prompt_passkey_template(context):
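
Both new steps reuse the function name step_impl; behave dispatches on the decorated pattern rather than the function name, so the duplication is harmless, though distinct names are easier to grep. The same registrations with descriptive names would read (a style sketch, not what the commit does):

    @step(u'{n_ga:d} group attention factor to extend context size through self-extend')
    def step_n_ga(context, n_ga):
        context.n_ga = n_ga

    @step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
    def step_n_ga_w(context, n_ga_w):
        context.n_ga_w = n_ga_w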
@@ -804,7 +816,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
     assert len(content) > 0, "no token predicted"
-    if expected_predicted_n is not None:
+    if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
     if re_content is not None:
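
With the reworked guard, None, 0 and negative values all skip the exact-count assertion while the regex check still runs, which is what lets the passkey example pass -1. A condensed illustration of the new behaviour (standalone, not from the commit):

    def count_check_enabled(expected_predicted_n):
        # mirrors the new guard: None, 0 and negatives disable the assert
        return bool(expected_predicted_n and expected_predicted_n > 0)

    assert count_check_enabled(32) is True
    assert count_check_enabled(-1) is False  # the passkey row's sentinel
    assert count_check_enabled(None) is False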
@@ -926,16 +938,20 @@ def start_server_background(context):
         server_args.append('--embedding')
     if context.server_metrics:
         server_args.append('--metrics')
-    if context.model_alias is not None:
+    if context.model_alias:
         server_args.extend(['--alias', context.model_alias])
-    if context.n_ctx is not None:
+    if context.n_ctx:
         server_args.extend(['--ctx-size', context.n_ctx])
-    if context.n_slots is not None:
+    if context.n_slots:
         server_args.extend(['--parallel', context.n_slots])
-    if context.n_server_predict is not None:
+    if context.n_server_predict:
         server_args.extend(['--n-predict', context.n_server_predict])
-    if context.server_api_key is not None:
+    if context.server_api_key:
         server_args.extend(['--api-key', context.server_api_key])
+    if context.n_ga:
+        server_args.extend(['--grp-attn-n', context.n_ga])
+    if context.n_ga_w:
+        server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
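
Note the conditions also change from `is not None` to plain truthiness, so a value of 0 or an empty string now skips the flag as well, and the two new blocks forward the group-attention parameters to the server's --grp-attn-n / --grp-attn-w options. For the updated phi-2 example row, the relevant slice of the argument list would come out roughly as follows (a sketch, assuming the 8224-token context computed above; the remaining flags are omitted):

    server_args = [
        '--ctx-size', 8224,    # 2048 * 4 + 32, per the KV-cache step
        '--parallel', 1,       # n_slots
        '--grp-attn-n', 4,     # n_ga, group attention factor
        '--grp-attn-w', 2048,  # n_ga_w, group attention width
    ]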