This commit is contained in:
ochafik 2024-10-27 16:44:54 +00:00
parent 080982ebf3
commit ec9f3b101b
3 changed files with 11 additions and 2 deletions

View file

@ -33,18 +33,23 @@ import shutil
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
def raise_exception(message: str):
raise ValueError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
def strftime_now(format):
now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
return now.strftime(format)
def handle_chat_template(output_folder, model_id, variant, template_src):
model_name = model_id.replace("/", "-")
base_name = f'{model_name}-{variant}' if variant else model_name
@ -111,6 +116,7 @@ def handle_chat_template(output_folder, model_id, variant, template_src):
# Output the line of arguments for the C++ test binary
print(f"{template_file} {context_file} {output_file}")
def main():
parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.")
parser.add_argument("output_folder", help="Folder to store all output files")
@ -144,5 +150,6 @@ def main():
except Exception as e:
logger.error(f"Error processing model {model_id}: {e}")
if __name__ == '__main__':
main()

View file

@ -52,7 +52,9 @@ def main(args):
ct['name']: ct['template']
for ct in chat_template
}
format_variants = lambda: ', '.join(f'"{v}"' for v in variants.keys())
def format_variants():
return ', '.join(f'"{v}"' for v in variants.keys())
if variant is None:
if 'default' not in variants:

View file

@ -253,7 +253,7 @@ static void test_parsing() {
};
auto special_function_call_with_id = json::parse(special_function_call.dump());
special_function_call_with_id["id"] = "123456789";
auto no_function_call = json::array();
test_parse_tool_call(llama_tool_call_style::Llama31, tools,