commit 98da78b228 (parent e91c5780a6)

    add LoRA test

3 changed files, 58 insertions(+), 0 deletions(-)
examples/server/tests/features/lora.feature (new file, +36)
@@ -0,0 +1,36 @@
+@llama.cpp
+@lora
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model url https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf
+    And a model file stories15M_MOE-F16.gguf
+    And a model alias stories15M_MOE
+    And a lora adapter file from https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf
+    And 42 as server seed
+    And 1024 as batch size
+    And 1024 as ubatch size
+    And 2048 KV cache size
+    And 64 max tokens to predict
+    And 0.0 temperature
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Completion LoRA disabled
+    Given switch off lora adapter 0
+    Given a prompt:
+    """
+    Look in thy glass
+    """
+    And a completion request with no api error
+    Then 64 tokens are predicted matching little|girl|three|years|old
+
+  Scenario: Completion LoRA enabled
+    Given switch on lora adapter 0
+    Given a prompt:
+    """
+    Look in thy glass
+    """
+    And a completion request with no api error
+    Then 64 tokens are predicted matching eye|love|glass|sun
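The feature exercises LoRA hot-swapping against the tiny stories15M_MOE model with a Shakespeare-tuned adapter. Both scenarios send the same Sonnet 3 prompt at temperature 0.0 with a fixed seed, so the outcome is deterministic: with the adapter scaled to 0, the base model continues in its children's-story register (matching little|girl|three|years|old); with it scaled to 1, the continuation shifts toward the sonnet's vocabulary (eye|love|glass|sun). The @lora tag lets behave run this feature in isolation (behave --tags=lora); a standalone sketch of the same round trip appears after the step definitions below.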
examples/server/tests/features/steps/steps.py
@@ -7,6 +7,7 @@ import subprocess
 import sys
 import threading
 import time
+import requests
 from collections.abc import Sequence
 from contextlib import closing
 from re import RegexFlag
@@ -70,6 +71,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.user_api_key = None
     context.response_format = None
     context.temperature = None
+    context.lora_file = None

     context.tasks_result = []
     context.concurrent_tasks = []
@@ -82,6 +84,12 @@ def step_download_hf_model(context, hf_file: str, hf_repo: str):
     context.model_hf_file = hf_file
     context.model_file = os.path.basename(hf_file)

+@step('a lora adapter file from {lora_file_url}')
+def step_download_lora_file(context, lora_file_url: str):
+    file_name = lora_file_url.split('/').pop()
+    context.lora_file = f'../../../{file_name}'
+    with open(context.lora_file, 'wb') as f:
+        f.write(requests.get(lora_file_url).content)

 @step('a model file {model_file}')
 def step_model_file(context, model_file: str):
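The step derives the file name from the last URL segment and saves the adapter to ../../../{file_name}, three levels above the tests' working directory, so the path recorded in context.lora_file can later be handed to the server via --lora (see start_server_background below).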
@@ -849,6 +857,17 @@ async def step_erase_slot(context, slot_id):
     context.response = response


+@step('switch {on_or_off} lora adapter {lora_id:d}')
+@async_run_until_complete
+async def toggle_lora_adapter(context, on_or_off: str, lora_id: int):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/lora-adapters',
+                                json=[{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}],
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+            print([{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}])
+
+
 @step('the server responds with status code {status_code:d}')
 def step_server_responds_with_status_code(context, status_code):
     assert context.response.status == status_code
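For illustration, here is a minimal synchronous sketch of the round trip this step (plus the completion steps) performs, outside the behave harness. It assumes a server already running on localhost:8080 with the adapter loaded, uses the same /lora-adapters payload as above, and the standard /completion fields (prompt, n_predict, temperature, seed):

    import re
    import requests

    BASE_URL = 'http://localhost:8080'  # matches the feature file's Background

    def completion_matches(scale: float, pattern: str) -> bool:
        # Rescale adapter 0: 1.0 enables the LoRA, 0.0 falls back to the base model.
        requests.post(f'{BASE_URL}/lora-adapters',
                      json=[{'id': 0, 'scale': scale}]).raise_for_status()
        # Same prompt and sampling settings as the feature file.
        res = requests.post(f'{BASE_URL}/completion',
                            json={'prompt': 'Look in thy glass',
                                  'n_predict': 64, 'temperature': 0.0, 'seed': 42})
        return re.search(pattern, res.json()['content']) is not None

    print(completion_matches(0.0, r'little|girl|three|years|old'))  # base model
    print(completion_matches(1.0, r'eye|love|glass|sun'))           # LoRA enabled

Because the endpoint takes a list, several adapters could in principle be rescaled in a single call; here only adapter 0 is touched.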
@@ -1326,6 +1345,8 @@ def start_server_background(context):
         server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
+    if context.lora_file:
+        server_args.extend(['--lora', context.lora_file])
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
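With context.lora_file populated by the download step, the server is launched with the adapter preloaded (for example --lora ../../../moe_shakespeare15M.gguf), so the test scenarios only toggle its scale at runtime rather than loading the file per request.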
examples/server/tests/requirements.txt
@@ -4,3 +4,4 @@ huggingface_hub~=0.20.3
 numpy~=1.26.4
 openai~=1.30.3
 prometheus-client~=0.20.0
+requests~=2.32.3