From ce067af1184ae4ca7fa326f136dc6364e98251be Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Thu, 2 May 2024 15:00:36 -0400
Subject: [PATCH] convert-hf : use an ABC for Model again

It seems Protocol can't be used as a statically type-checked ABC,
because its subclasses also can't be instantiated. (Why did it seem
to work before?)

At least there's still a way to raise an error when the `model_arch`
property is left undefined in a registered Model subclass.
---
 convert-hf-to-gguf.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 91f9d127b..a40924408 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -8,10 +8,11 @@ import json
 import os
 import re
 import sys
+from abc import ABC
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Protocol, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
 
 import numpy as np
 import torch
@@ -40,7 +41,7 @@ class SentencePieceTokenTypes(IntEnum):
 AnyModel = TypeVar("AnyModel", bound="type[Model]")
 
 
-class Model(Protocol):
+class Model(ABC):
     _model_classes: dict[str, type[Model]] = {}
 
     dir_model: Path
@@ -57,6 +58,7 @@ class Model(Protocol):
     tensor_map: gguf.TensorNameMap
     tensors: dict[str, Tensor]
 
+    # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
@@ -163,12 +165,18 @@
         print(f"gguf: file type = {self.ftype}")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+        del name, new_name, bid, n_dims  # unused
+
         return False
 
     def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+        del name, new_name, bid, n_dims  # unused
+
         return False
 
     def write_tensors(self):
@@ -248,7 +256,7 @@
         return part_names
 
     @staticmethod
-    def load_hparams(dir_model):
+    def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             return json.load(f)
 
@@ -258,12 +266,14 @@
 
         def func(modelcls: AnyModel) -> AnyModel:
             for name in names:
+                if "model_arch" not in modelcls.__dict__:
+                    raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
                 cls._model_classes[name] = modelcls
             return modelcls
         return func
 
     @classmethod
-    def from_model_architecture(cls, arch):
+    def from_model_architecture(cls, arch: str) -> type[Model]:
         try:
             return cls._model_classes[arch]
         except KeyError:
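
Note (illustration, not part of the patch): the registration-time guard
added above can be exercised on its own. Below is a minimal, runnable
sketch of the same pattern; the names BaseModel, MODEL_ARCH, LlamaDemo,
and Broken are hypothetical stand-ins, not names from
convert-hf-to-gguf.py. Since `model_arch` on the base class is an
annotation without a value, it never appears in any class's `__dict__`,
and the stricter `__dict__` check (rather than `hasattr`) also refuses
inherited values, so each registered class must state its own
architecture directly.

# sketch.py -- illustrative stand-ins only, not convert-hf-to-gguf.py code
from __future__ import annotations

from abc import ABC
from enum import IntEnum
from typing import Callable, TypeVar


class MODEL_ARCH(IntEnum):  # stand-in for gguf.MODEL_ARCH
    LLAMA = 1


AnyModel = TypeVar("AnyModel", bound="type[BaseModel]")


class BaseModel(ABC):
    _model_classes: dict[str, type[BaseModel]] = {}

    # subclasses should define this! (annotation only: no value lands
    # in any class's __dict__ until a subclass assigns one)
    model_arch: MODEL_ARCH

    @classmethod
    def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
        def func(modelcls: AnyModel) -> AnyModel:
            for name in names:
                # __dict__ (not hasattr): the value must be defined
                # directly on the registered class, not inherited.
                if "model_arch" not in modelcls.__dict__:
                    raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
                cls._model_classes[name] = modelcls
            return modelcls
        return func


@BaseModel.register("LlamaDemo")
class LlamaDemo(BaseModel):
    model_arch = MODEL_ARCH.LLAMA  # defined directly: registration succeeds


try:
    @BaseModel.register("Broken")
    class Broken(BaseModel):  # forgot to define model_arch
        pass
except TypeError as e:
    print(e)  # -> Missing property 'model_arch' for 'Broken'

Running the sketch registers LlamaDemo cleanly and prints the TypeError
message for Broken: the check fires when the decorator is applied, i.e.
at import time, long before any conversion work runs.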