convert-hf : use a plain class for Model, and forbid direct instantiation

There are no abstract methods used anyway,
so using ABC isn't really necessary.
This commit is contained in:
Francis Couture-Harpin 2024-05-02 15:50:21 -04:00
parent ce067af118
commit 13f4cf70db

View file

@ -8,7 +8,6 @@ import json
import os import os
import re import re
import sys import sys
from abc import ABC
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from hashlib import sha256 from hashlib import sha256
@ -41,7 +40,7 @@ class SentencePieceTokenTypes(IntEnum):
AnyModel = TypeVar("AnyModel", bound="type[Model]") AnyModel = TypeVar("AnyModel", bound="type[Model]")
class Model(ABC): class Model:
_model_classes: dict[str, type[Model]] = {} _model_classes: dict[str, type[Model]] = {}
dir_model: Path dir_model: Path
@ -62,6 +61,8 @@ class Model(ABC):
model_arch: gguf.MODEL_ARCH model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool): def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
if self.__class__ == Model:
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
self.dir_model = dir_model self.dir_model = dir_model
self.ftype = ftype self.ftype = ftype
self.fname_out = fname_out self.fname_out = fname_out
@ -78,6 +79,13 @@ class Model(ABC):
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensors = dict(self.get_tensors()) self.tensors = dict(self.get_tensors())
@classmethod
def __init_subclass__(cls):
# can't use an abstract property, because overriding it without type errors
# would require using decorated functions instead of simply defining the property
if "model_arch" not in cls.__dict__:
raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
key = next((k for k in keys if k in self.hparams), None) key = next((k for k in keys if k in self.hparams), None)
if key is not None: if key is not None:
@ -266,8 +274,6 @@ class Model(ABC):
def func(modelcls: AnyModel) -> AnyModel: def func(modelcls: AnyModel) -> AnyModel:
for name in names: for name in names:
if "model_arch" not in modelcls.__dict__:
raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
cls._model_classes[name] = modelcls cls._model_classes[name] = modelcls
return modelcls return modelcls
return func return func