Make convert script with pytorch files
This commit is contained in:
parent
51b3b56c08
commit
a00bb06c43
1 changed file with 43 additions and 3 deletions
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -20,6 +21,16 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
||||||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
||||||
import gguf
|
import gguf
|
||||||
|
|
||||||
|
def count_model_parts(dir_model: Path, prefix: str) -> int:
    """Count the model shard files in *dir_model* whose names start with *prefix*.

    Used to detect how many sharded weight files (``model-00...`` safetensors
    parts or ``pytorch_model-...`` .bin parts) a model directory contains.

    Args:
        dir_model: Directory to scan (not searched recursively).
        prefix: Filename prefix that identifies one model part.

    Returns:
        The number of matching files. Also prints a progress line when at
        least one part is found, matching the script's other status messages.
    """
    num_parts = sum(1 for filename in os.listdir(dir_model) if filename.startswith(prefix))

    if num_parts > 0:
        # Same message text as the original concatenation-based print.
        print(f"gguf: found {num_parts} model parts")

    return num_parts
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
|
parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
|
||||||
|
@ -141,16 +152,45 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
|
||||||
# tensor info
print("gguf: get tensor metadata")

# Figure out how the weights are stored on disk: sharded safetensors,
# a single safetensors file, sharded pytorch .bin files, or a single .bin.
num_parts = count_model_parts(dir_model, "model-00")
if num_parts:
    is_safetensors = True
else:
    if count_model_parts(dir_model, "model.safetensors") > 0:
        is_safetensors = True
        num_parts = 0
    else:
        is_safetensors = False
        num_parts = count_model_parts(dir_model, "pytorch_model-")

if is_safetensors:
    # BUGFIX: the original only imported safe_open in the sharded branch,
    # so a single-file "model.safetensors" model hit a NameError when
    # safe_open() was called below. Import it for every safetensors case.
    from safetensors import safe_open

# Build the sequence of part filenames to load, in shard order.
if is_safetensors and num_parts == 0:
    part_names = iter(("model.safetensors",))
elif num_parts == 0:
    part_names = iter(("pytorch_model.bin",))
elif is_safetensors:
    part_names = (
        f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
    )
else:
    part_names = (
        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
    )
for part_name in part_names:
|
for part_name in part_names:
|
||||||
if args.vocab_only:
|
if args.vocab_only:
|
||||||
break
|
break
|
||||||
print("gguf: loading model part '" + part_name + "'")
|
print("gguf: loading model part '" + part_name + "'")
|
||||||
|
if is_safetensors:
|
||||||
ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
|
ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
|
||||||
|
else:
|
||||||
|
ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu"))
|
||||||
|
|
||||||
with ctx as model_part:
|
with ctx as model_part:
|
||||||
for name in model_part.keys():
|
for name in model_part.keys():
|
||||||
data = model_part.get_tensor(name)
|
data = model_part.get_tensor(name) if is_safetensors else model_part[name]
|
||||||
|
|
||||||
# we don't need these
|
# we don't need these
|
||||||
if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
|
if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue