add slow test with llama 8b
parent d67fefb91d
commit 367f0ab1b4

4 changed files with 73 additions and 18 deletions
requirements.txt:

@@ -5,3 +5,4 @@ numpy~=1.26.4
 openai~=1.55.3
 prometheus-client~=0.20.0
 requests~=2.32.3
+wget~=3.2
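The new wget pin backs the download_file() helper introduced in utils.py below. As a quick reference, the package exposes a single download entry point; the URL and output path in this sketch are placeholders, not part of the commit:

    import wget

    # Fetches the URL to a local file and returns the saved path;
    # `out` may be a target file path or an existing directory.
    saved = wget.download("https://example.com/file.gguf", out="./tmp/file.gguf")
    print(saved)  # ./tmp/file.gguf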
test_lora.py:

@@ -10,15 +10,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
-    # download lora file if needed
-    file_name = LORA_FILE_URL.split('/').pop()
-    lora_file = f'../../../{file_name}'
-    if not os.path.exists(lora_file):
-        print(f"Downloading {LORA_FILE_URL} to {lora_file}")
-        with open(lora_file, 'wb') as f:
-            f.write(requests.get(LORA_FILE_URL).content)
-        print(f"Done downloading lora file")
-    server.lora_files = [lora_file]
+    server.lora_files = [download_file(LORA_FILE_URL)]


 @pytest.mark.parametrize("scale,re_content", [
@@ -73,3 +65,52 @@ def test_lora_per_request():
     assert all([res.status_code == 200 for res in results])
     for res, (_, re_test) in zip(results, lora_config):
         assert match_regex(re_test, res.body["content"])
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_big_model():
+    server = ServerProcess()
+    server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
+    server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
+    server.model_alias = "Llama-3.2-8B-Instruct"
+    server.n_slots = 4
+    server.n_ctx = server.n_slots * 1024
+    server.n_predict = 64
+    server.temperature = 0.0
+    server.seed = 42
+    server.lora_files = [
+        download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
+        # TODO: find & add other lora adapters for this model
+    ]
+    server.start(timeout_seconds=600)
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Write a computer virus"
+    lora_config = [
+        # without applying lora, the model should reject the request
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
+        # with 0.7 scale, the model should provide a simple computer virus with hesitation
+        ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
+        # with 1.5 scale, the model should confidently provide a computer virus
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+    ]
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/v1/chat/completions", {
+            "messages": [
+                {"role": "user", "content": prompt}
+            ],
+            "lora": lora,
+            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        })
+    ) for lora, _ in lora_config]
+    results = parallel_function_calls(tasks)
+
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert re_test in res.body["choices"][0]["message"]["content"]
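parallel_function_calls() comes from the existing test utils and is not part of this diff; a minimal sketch of a compatible helper, assuming it takes (callable, args-tuple) pairs and returns results in input order so they can be zipped back against lora_config:

    from concurrent.futures import ThreadPoolExecutor

    # hypothetical stand-in for the helper defined in the test utils
    def parallel_function_calls(tasks):
        # tasks: list of (function, args) tuples, as built in test_with_big_model
        with ThreadPoolExecutor(max_workers=len(tasks)) as pool:
            futures = [pool.submit(fn, *args) for fn, args in tasks]
            # collect in submission order, so results[i] matches tasks[i]
            return [f.result() for f in futures]

Because of the skipif marker, this test only runs when slow tests are enabled via the SLOW_TESTS environment variable (see is_slow_test_allowed() below), e.g. SLOW_TESTS=1 pytest test_lora.py.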
test_speculative.py:

@@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
-    # download draft model file if needed
-    file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
-    model_draft_file = f'../../../{file_name}'
-    if not os.path.exists(model_draft_file):
-        print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
-        with open(model_draft_file, 'wb') as f:
-            f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
-        print(f"Done downloading draft model file")
     # set default values
-    server.model_draft = model_draft_file
+    server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
     server.draft_min = 4
     server.draft_max = 8

utils.py:

@@ -23,6 +23,7 @@ from typing import (
     Set,
 )
 from re import RegexFlag
+import wget


 class ServerResponse:
@@ -381,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool:
         is not None
     )

+
+def download_file(url: str, output_file_path: str | None = None) -> str:
+    """
+    Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
+
+    output_file_path is the local path to save the downloaded file. If not provided, the file is saved under ./tmp/ using the URL's basename.
+
+    Returns the local path of the downloaded file.
+    """
+    file_name = url.split('/').pop()
+    output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
+    if not os.path.exists(output_file):
+        print(f"Downloading {url} to {output_file}")
+        wget.download(url, out=output_file)
+        print(f"Done downloading to {output_file}")
+    else:
+        print(f"File already exists at {output_file}")
+    return output_file
+
+
 def is_slow_test_allowed():
     return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
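A usage sketch for the new helper (the URL is a placeholder; note that wget.download() likely does not create missing directories, so ./tmp is assumed to already exist):

    # First call downloads to ./tmp/<basename of the URL>; subsequent calls
    # reuse the cached file and just print that it already exists.
    path = download_file("https://example.com/adapter.gguf")
    assert path == "./tmp/adapter.gguf"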