Add arch argument for convert.py

Signed-off-by: caiyesd <caiyesd@gmail.com>
This commit is contained in:
caiyesd 2023-11-23 10:30:43 +08:00
parent 8e672efe63
commit 625954985c

View file

@@ -817,7 +817,7 @@ class OutputFile:
def add_meta_arch(self, params: Params) -> None:
name = "LLaMA"
if ARCH == gguf.MODEL_ARCH.LLAMA:
# TODO: better logic to determine model name
if params.n_ctx == 4096:
name = "LLaMA v2"
@@ -1138,6 +1138,9 @@ def main(args_in: list[str] | None = None) -> None:
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
# We currently only support Q8_0 output on little endian systems.
output_choices.append("q8_0")
global ARCH
DEFAULT_ARCH_NAME = gguf.MODEL_ARCH_NAMES[ARCH]
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
@@ -1150,6 +1153,7 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
parser.add_argument("--arch", choices=list(gguf.MODEL_ARCH_NAMES.values()), help=f"model arch name (default: {DEFAULT_ARCH_NAME})", default = DEFAULT_ARCH_NAME)
args = parser.parse_args(args_in)
if args.dump_single:
@@ -1169,6 +1173,14 @@ def main(args_in: list[str] | None = None) -> None:
if args.bigendian:
endianess = gguf.GGUFEndian.BIG
for arch, name in gguf.MODEL_ARCH_NAMES.items():
if args.arch == name:
ARCH = arch # modify global ARCH
if ARCH not in gguf.MODEL_ARCH_NAMES.keys():
raise Exception(f"Invalid arch name: {args.arch}")
print(f"ARCH = {gguf.MODEL_ARCH_NAMES[ARCH]}")
params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None: