Update fetch_server_test_models.py

ochafik 2024-12-27 00:58:59 +00:00
parent 0e87ae24cd
commit 0a5d527508

scripts/fetch_server_test_models.py — 124 changed lines (Normal file → Executable file)

@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 '''
     This script fetches all the models used in the server tests.
@@ -7,13 +8,14 @@
     Example:
         python scripts/fetch_server_test_models.py
-        ( cd examples/server/tests && ./tests.sh --tags=slow )
+        ( cd examples/server/tests && ./tests.sh -v -x -m slow )
 '''
-from behave.parser import Parser
+import ast
 import glob
+import logging
 import os
+from typing import Generator
 from pydantic import BaseModel
-import re
 import subprocess
 import sys
@@ -26,53 +28,71 @@ class HuggingFaceModel(BaseModel):
         frozen = True
 
-models = set()
-
-model_file_re = re.compile(r'a model file ([^\s\n\r]+) from HF repo ([^\s\n\r]+)')
-
-
-def process_step(step):
-    if (match := model_file_re.search(step.name)):
-        (hf_file, hf_repo) = match.groups()
-        models.add(HuggingFaceModel(hf_repo=hf_repo, hf_file=hf_file))
-
-
-feature_files = glob.glob(
-    os.path.join(
-        os.path.dirname(__file__),
-        '../examples/server/tests/features/*.feature'))
-
-for feature_file in feature_files:
-    with open(feature_file, 'r') as file:
-        feature = Parser().parse(file.read())
-        if not feature: continue
-
-        if feature.background:
-            for step in feature.background.steps:
-                process_step(step)
-
-        for scenario in feature.walk_scenarios(with_outlines=True):
-            for step in scenario.steps:
-                process_step(step)
-
-cli_path = os.environ.get(
-    'LLAMA_SERVER_BIN_PATH',
-    os.path.join(
-        os.path.dirname(__file__),
-        '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
-
-for m in sorted(list(models), key=lambda m: m.hf_repo):
-    if '<' in m.hf_repo or '<' in m.hf_file:
-        continue
-    if '-of-' in m.hf_file:
-        print(f'# Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file', file=sys.stderr)
-        continue
-
-    print(f'# Ensuring model at {m.hf_repo} / {m.hf_file} is fetched')
-    cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
-    if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
-        cmd.append('-fa')
-    try:
-        subprocess.check_call(cmd)
-    except subprocess.CalledProcessError:
-        print(f'# Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n  {" ".join(cmd)}', file=sys.stderr)
-        exit(1)
+
+def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
+    try:
+        with open(test_file) as f:
+            tree = ast.parse(f.read())
+    except Exception as e:
+        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
+        return
+
+    for node in ast.walk(tree):
+        if isinstance(node, ast.FunctionDef):
+            for dec in node.decorator_list:
+                if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
+                    param_names = ast.literal_eval(dec.args[0]).split(",")
+                    if not "hf_repo" in param_names or not "hf_file" in param_names:
+                        continue
+
+                    raw_param_values = dec.args[1]
+                    if not isinstance(raw_param_values, ast.List):
+                        logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
+                        continue
+
+                    hf_repo_idx = param_names.index("hf_repo")
+                    hf_file_idx = param_names.index("hf_file")
+
+                    for t in raw_param_values.elts:
+                        if not isinstance(t, ast.Tuple):
+                            logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
+                            continue
+                        yield HuggingFaceModel(
+                            hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
+                            hf_file=ast.literal_eval(t.elts[hf_file_idx]))
+
+
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+
+    models = sorted(list(set([
+        model
+        for test_file in glob.glob('examples/server/tests/unit/test_*.py')
+        for model in collect_hf_model_test_parameters(test_file)
+    ])), key=lambda m: (m.hf_repo, m.hf_file))
+
+    logging.info(f'Found {len(models)} models in parameterized tests:')
+    for m in models:
+        logging.info(f'  - {m.hf_repo} / {m.hf_file}')
+
+    cli_path = os.environ.get(
+        'LLAMA_SERVER_BIN_PATH',
+        os.path.join(
+            os.path.dirname(__file__),
+            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' \
+                else '../build/bin/llama-cli'))
+
+    for m in models:
+        if '<' in m.hf_repo or '<' in m.hf_file:
+            continue
+        if '-of-' in m.hf_file:
+            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
+            continue
+
+        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
+        cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
+        if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
+            cmd.append('-fa')
+        try:
+            subprocess.check_call(cmd)
+        except subprocess.CalledProcessError:
+            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n  {" ".join(cmd)}')
+            exit(1)
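
For context on what the new AST-based collector extracts, here is a small standalone sketch (not part of the commit): it parses a toy test file and pulls the hf_repo/hf_file pairs out of a pytest.mark.parametrize decorator, the same way collect_hf_model_test_parameters does. The toy test body and the repo/file names are made up for illustration.

# Standalone sketch, not part of the commit; the test contents and the
# repo/file names below are hypothetical.
import ast

toy_test = '''
import pytest


@pytest.mark.parametrize("hf_repo,hf_file,n_predict", [
    ("ggml-org/models", "tinyllamas/stories260K.gguf", 8),
    ("example-org/Example-GGUF", "Example-Q4_K_M.gguf", 8),
])
def test_completion(hf_repo: str, hf_file: str, n_predict: int):
    pass
'''

tree = ast.parse(toy_test)
for node in ast.walk(tree):
    if not isinstance(node, ast.FunctionDef):
        continue
    for dec in node.decorator_list:
        # pytest.mark.parametrize(...) appears as a Call whose func is an
        # Attribute with attr == 'parametrize'.
        if not (isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize'):
            continue
        param_names = ast.literal_eval(dec.args[0]).split(",")
        repo_idx = param_names.index("hf_repo")
        file_idx = param_names.index("hf_file")
        for t in dec.args[1].elts:  # each ast.Tuple in the parameter list
            print(ast.literal_eval(t.elts[repo_idx]), '/', ast.literal_eval(t.elts[file_idx]))

# Prints:
#   ggml-org/models / tinyllamas/stories260K.gguf
#   example-org/Example-GGUF / Example-Q4_K_M.gguf

Scanning the test files with ast rather than importing them means the model list can be collected without installing the tests' runtime dependencies or executing any test code.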