Make stablelm conversion script use .safetensors

This commit is contained in:
Galunid 2023-10-18 14:51:50 +02:00
parent 605e701cb4
commit 1ee5cc3076

View file

@ -14,23 +14,13 @@ from typing import Any
import numpy as np
import torch
from transformers import AutoTokenizer # type: ignore[import]
from safetensors import safe_open
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf
def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1
if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
parser.add_argument(
@ -82,8 +72,6 @@ if hparams["architectures"][0] != "StableLMEpochForCausalLM":
sys.exit()
# get number of model parts
num_parts = count_model_parts(dir_model)
ARCH=gguf.MODEL_ARCH.STABLELM
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
@ -145,21 +133,16 @@ print(tensor_map)
# tensor info
print("gguf: get tensor metadata")
if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)
part_names = iter(("model.safetensors",))
for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
with ctx as model_part:
for name in model_part.keys():
data = model_part[name]
data = model_part.get_tensor(name)
# we don't need these
if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):