Add arch argument for convert.py

Signed-off-by: caiyesd <caiyesd@gmail.com>
This commit is contained in:
caiyesd 2023-11-23 10:30:43 +08:00
parent 8e672efe63
commit 625954985c

View file

@ -817,10 +817,10 @@ class OutputFile:
def add_meta_arch(self, params: Params) -> None: def add_meta_arch(self, params: Params) -> None:
name = "LLaMA" name = "LLaMA"
if ARCH == gguf.MODEL_ARCH.LLAMA:
# TODO: better logic to determine model name # TODO: better logic to determine model name
if params.n_ctx == 4096: if params.n_ctx == 4096:
name = "LLaMA v2" name = "LLaMA v2"
elif params.path_model is not None: elif params.path_model is not None:
name = str(params.path_model.parent).split('/')[-1] name = str(params.path_model.parent).split('/')[-1]
@ -1138,6 +1138,9 @@ def main(args_in: list[str] | None = None) -> None:
if np.uint32(1) == np.uint32(1).newbyteorder("<"): if np.uint32(1) == np.uint32(1).newbyteorder("<"):
# We currently only support Q8_0 output on little endian systems. # We currently only support Q8_0 output on little endian systems.
output_choices.append("q8_0") output_choices.append("q8_0")
global ARCH
DEFAULT_ARCH_NAME = gguf.MODEL_ARCH_NAMES[ARCH]
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
@ -1150,6 +1153,7 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
parser.add_argument("--arch", choices=list(gguf.MODEL_ARCH_NAMES.values()), help=f"model arch name (default: {DEFAULT_ARCH_NAME})", default = DEFAULT_ARCH_NAME)
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
if args.dump_single: if args.dump_single:
@ -1169,6 +1173,14 @@ def main(args_in: list[str] | None = None) -> None:
if args.bigendian: if args.bigendian:
endianess = gguf.GGUFEndian.BIG endianess = gguf.GGUFEndian.BIG
for arch, name in gguf.MODEL_ARCH_NAMES.items():
if args.arch == name:
ARCH = arch # modify global ARCH
if ARCH not in gguf.MODEL_ARCH_NAMES.keys():
raise Exception(f"Invalid arch name: {args.arch}")
print(f"ARCH = {gguf.MODEL_ARCH_NAMES[ARCH]}")
params = Params.load(model_plus) params = Params.load(model_plus)
if params.n_ctx == -1: if params.n_ctx == -1:
if args.ctx is None: if args.ctx is None: