diff --git a/convert-stablelm-hf-to-gguf.py b/convert-stablelm-hf-to-gguf.py
index 3bd8bdda0..ad2727973 100755
--- a/convert-stablelm-hf-to-gguf.py
+++ b/convert-stablelm-hf-to-gguf.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import contextlib
 import argparse
 import json
 import os
@@ -20,6 +21,16 @@ if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
 import gguf
 
+def count_model_parts(dir_model: Path, prefix: str) -> int:
+    num_parts = 0
+    for filename in os.listdir(dir_model):
+        if filename.startswith(prefix):
+            num_parts += 1
+
+    if num_parts > 0:
+        print("gguf: found " + str(num_parts) + " model parts")
+    return num_parts
+
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
@@ -141,16 +152,45 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 # tensor info
 print("gguf: get tensor metadata")
 
-part_names = iter(("model.safetensors",))
+# get number of model parts
+num_parts = count_model_parts(dir_model, "model-00")
+if num_parts:
+    is_safetensors = True
+    from safetensors import safe_open
+else:
+    if count_model_parts(dir_model, "model.safetensors") > 0:
+        is_safetensors = True
+        num_parts = 0
+    else:
+        is_safetensors = False
+        num_parts = count_model_parts(dir_model, "pytorch_model-")
+
+if is_safetensors and num_parts == 0:
+    part_names = iter(("model.safetensors",))
+elif num_parts == 0:
+    part_names = iter(("pytorch_model.bin",))
+elif is_safetensors:
+    part_names = (
+        f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
+    )
+else:
+    part_names = (
+        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
+    )
+
 
 for part_name in part_names:
     if args.vocab_only:
         break
     print("gguf: loading model part '" + part_name + "'")
-    ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
+    if is_safetensors:
+        ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
+    else:
+        ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu"))
+
     with ctx as model_part:
         for name in model_part.keys():
-            data = model_part.get_tensor(name)
+            data = model_part.get_tensor(name) if is_safetensors else model_part[name]
 
             # we don't need these
             if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):