override load_hparams

This commit is contained in:
JustinLin610 2024-04-17 01:53:56 +08:00
parent 6cf0d467f8
commit f22900b30c

View file

@ -195,12 +195,8 @@ class Model(ABC):
@staticmethod
def load_hparams(dir_model):
with open(dir_model / "config.json", "r", encoding="utf-8") as f1, \
open(dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f2:
hparams = json.load(f1)
hparams.update(json.load(f2))
return hparams
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@ -1713,8 +1709,20 @@ class QwenModel(Model):
class Qwen2Model(Model):
model_arch = gguf.MODEL_ARCH.QWEN2
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
super().__init__(dir_model, ftype, fname_out, is_big_endian, use_temp_file)
self.hparams = Qwen2Model.load_hparams(dir_model)
@staticmethod
def load_hparams(dir_model):
with open(dir_model / "config.json", "r", encoding="utf-8") as f1, \
open(dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f2:
hparams = json.load(f1)
hparams.update(json.load(f2))
return hparams
def set_vocab(self):
print(f'Tokenizer class: {self.hparams.get("tokenizer_class")}')
if self.hparams.get("tokenizer_class") == "PreTrainedTokenizerFast":
self._set_vocab_sentencepiece()
else: