Small refactor

This commit is contained in:
Galunid 2023-10-31 16:23:53 +01:00
parent c94df09732
commit 235acc18cd
2 changed files with 23 additions and 27 deletions

View file

@@ -28,41 +28,25 @@ else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
print("gguf: loading model " + dir_model.name)
print(f"Loading model: {dir_model.name}")
hparams = model.Model.load_hparams(dir_model)
model_class = model.Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype, fname_out)
print("gguf: get model metadata")
print("Set model parameters")
model_instance.set_gguf_parameters()
# TOKENIZATION
print("gguf: get tokenizer metadata")
print("gguf: get gpt2 tokenizer vocab")
print("Set model tokenizer")
model_instance.set_vocab()
# write model
if not args.vocab_only:
model_instance.write_tensors()
print("gguf: write header")
model_instance.gguf_writer.write_header_to_file()
print("gguf: write metadata")
model_instance.gguf_writer.write_kv_data_to_file()
print("gguf: write tensors")
model_instance.gguf_writer.write_tensors_to_file()
print(f"Exporting model to '{fname_out}'")
model_instance.write()
else:
print("gguf: write header")
model_instance.gguf_writer.write_header_to_file()
print("gguf: write metadata")
model_instance.gguf_writer.write_kv_data_to_file()
print(f"Exporting model vocab to '{fname_out}'")
model_instance.write_vocab()
model_instance.gguf_writer.close()
print(f"gguf: model successfully exported to '{fname_out}'")
print(f"Model successfully exported to '{fname_out}'")
print("")

View file

@@ -9,7 +9,7 @@ import numpy as np
from enum import Enum
from pathlib import Path
from typing import TypeAlias, Any
from typing import TypeAlias, Any, Generator
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
@@ -48,7 +48,7 @@ class Model:
return ("pytorch_model.bin",)
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
def _get_model_architecture(self):
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
arch = self.hparams["architectures"][0]
if arch == "GPTNeoXForCausalLM":
return gguf.MODEL_ARCH.GPTNEOX
@@ -160,7 +160,7 @@ class Model:
def set_vocab(self):
self._set_vocab_gpt2()
def get_tensors(self):
def get_tensors(self) -> Generator[str, Any]:
for part_name in self.part_names:
print("gguf: loading model part '" + part_name + "'")
if self.is_safetensors:
@@ -230,6 +230,18 @@ class Model:
self.gguf_writer.add_tensor(new_name, data)
def write(self) -> None:
    """Export the complete model (metadata and tensors) to the GGUF output file.

    Registers all tensors with the writer first, then emits the file
    sections in the order the GGUF format requires, and closes the writer.
    """
    # Must run before the write_*_to_file calls: it registers every tensor
    # with self.gguf_writer so the header/tensor sections are complete.
    self.write_tensors()
    self.gguf_writer.write_header_to_file()
    self.gguf_writer.write_kv_data_to_file()
    self.gguf_writer.write_tensors_to_file()
    # Flush and release the output file handle.
    self.gguf_writer.close()
def write_vocab(self) -> None:
    """Export only the vocabulary/metadata (no tensors) to the GGUF output file.

    Used for the vocab-only conversion path: writes the header and the
    key-value metadata (which includes the tokenizer data set earlier via
    set_vocab), then closes the writer. write_tensors is deliberately
    not called.
    """
    self.gguf_writer.write_header_to_file()
    self.gguf_writer.write_kv_data_to_file()
    # Flush and release the output file handle.
    self.gguf_writer.close()
@staticmethod
def count_model_parts(dir_model: Path, prefix: str) -> int:
num_parts = 0