Add arch argument for convert.py
Signed-off-by: caiyesd <caiyesd@gmail.com>
This commit is contained in:
parent
8e672efe63
commit
625954985c
1 changed files with 16 additions and 4 deletions
20
convert.py
20
convert.py
|
@ -817,10 +817,10 @@ class OutputFile:
|
|||
|
||||
def add_meta_arch(self, params: Params) -> None:
|
||||
name = "LLaMA"
|
||||
|
||||
# TODO: better logic to determine model name
|
||||
if params.n_ctx == 4096:
|
||||
name = "LLaMA v2"
|
||||
if ARCH == gguf.MODEL_ARCH.LLAMA:
|
||||
# TODO: better logic to determine model name
|
||||
if params.n_ctx == 4096:
|
||||
name = "LLaMA v2"
|
||||
elif params.path_model is not None:
|
||||
name = str(params.path_model.parent).split('/')[-1]
|
||||
|
||||
|
@ -1138,6 +1138,9 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
|
||||
# We currently only support Q8_0 output on little endian systems.
|
||||
output_choices.append("q8_0")
|
||||
|
||||
global ARCH
|
||||
DEFAULT_ARCH_NAME = gguf.MODEL_ARCH_NAMES[ARCH]
|
||||
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
|
||||
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
||||
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
||||
|
@ -1150,6 +1153,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
|
||||
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
|
||||
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
|
||||
parser.add_argument("--arch", choices=list(gguf.MODEL_ARCH_NAMES.values()), help=f"model arch name (default: {DEFAULT_ARCH_NAME})", default = DEFAULT_ARCH_NAME)
|
||||
|
||||
args = parser.parse_args(args_in)
|
||||
if args.dump_single:
|
||||
|
@ -1169,6 +1173,14 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
if args.bigendian:
|
||||
endianess = gguf.GGUFEndian.BIG
|
||||
|
||||
for arch, name in gguf.MODEL_ARCH_NAMES.items():
|
||||
if args.arch == name:
|
||||
ARCH = arch # modify global ARCH
|
||||
if ARCH not in gguf.MODEL_ARCH_NAMES.keys():
|
||||
raise Exception(f"Invalid arch name: {args.arch}")
|
||||
|
||||
print(f"ARCH = {gguf.MODEL_ARCH_NAMES[ARCH]}")
|
||||
|
||||
params = Params.load(model_plus)
|
||||
if params.n_ctx == -1:
|
||||
if args.ctx is None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue