From 625954985c39700119074f2dd32c999392f0d8da Mon Sep 17 00:00:00 2001 From: caiyesd Date: Thu, 23 Nov 2023 10:30:43 +0800 Subject: [PATCH] Add arch argument for convert.py Signed-off-by: caiyesd --- convert.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/convert.py b/convert.py index 3ad836ce0..d1a34ba4e 100644 --- a/convert.py +++ b/convert.py @@ -817,10 +817,10 @@ class OutputFile: def add_meta_arch(self, params: Params) -> None: name = "LLaMA" - - # TODO: better logic to determine model name - if params.n_ctx == 4096: - name = "LLaMA v2" + if ARCH == gguf.MODEL_ARCH.LLAMA: + # TODO: better logic to determine model name + if params.n_ctx == 4096: + name = "LLaMA v2" elif params.path_model is not None: name = str(params.path_model.parent).split('/')[-1] @@ -1138,6 +1138,9 @@ def main(args_in: list[str] | None = None) -> None: if np.uint32(1) == np.uint32(1).newbyteorder("<"): # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") + + global ARCH + DEFAULT_ARCH_NAME = gguf.MODEL_ARCH_NAMES[ARCH] parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") @@ -1150,6 +1153,7 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") + parser.add_argument("--arch", choices=list(gguf.MODEL_ARCH_NAMES.values()), help=f"model arch name (default: {DEFAULT_ARCH_NAME})", default = DEFAULT_ARCH_NAME) args = parser.parse_args(args_in) if args.dump_single: @@ -1169,6 +1173,14 @@ def main(args_in: list[str] | None = None) -> None: if args.bigendian: endianess = gguf.GGUFEndian.BIG + for arch, name in gguf.MODEL_ARCH_NAMES.items(): + if args.arch == name: + ARCH = arch # modify global ARCH + if ARCH not in gguf.MODEL_ARCH_NAMES.keys(): + raise Exception(f"Invalid arch name: {args.arch}") + + print(f"ARCH = {gguf.MODEL_ARCH_NAMES[ARCH]}") + params = Params.load(model_plus) if params.n_ctx == -1: if args.ctx is None: