From 13f4cf70dbafe0d652ed9f1a3961fb40428d56bc Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Thu, 2 May 2024 15:50:21 -0400
Subject: [PATCH] convert-hf : use a plain class for Model, and forbid direct instantiation

There are no abstract methods used anyway, so using ABC isn't really necessary.
---
 convert-hf-to-gguf.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index a40924408..3191afcfd 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -8,7 +8,6 @@ import json
 import os
 import re
 import sys
-from abc import ABC
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
@@ -41,7 +40,7 @@ class SentencePieceTokenTypes(IntEnum):
 AnyModel = TypeVar("AnyModel", bound="type[Model]")
 
 
-class Model(ABC):
+class Model:
     _model_classes: dict[str, type[Model]] = {}
 
     dir_model: Path
@@ -62,6 +61,8 @@ class Model(ABC):
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
+        if self.__class__ == Model:
+            raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
         self.ftype = ftype
         self.fname_out = fname_out
@@ -78,6 +79,13 @@ class Model(ABC):
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensors = dict(self.get_tensors())
 
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
     def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
         key = next((k for k in keys if k in self.hparams), None)
         if key is not None:
@@ -266,8 +274,6 @@ class Model(ABC):
 
         def func(modelcls: AnyModel) -> AnyModel:
             for name in names:
-                if "model_arch" not in modelcls.__dict__:
-                    raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
                 cls._model_classes[name] = modelcls
             return modelcls
         return func