convert-hf : use an ABC for Model again

It seems Protocol can't be used as a statically type-checked ABC,
because its subclasses also can't be instantiated. (why did it seem to work?)

At least there's still a way to raise an error when a registered Model
subclass forgets to define the `model_arch` property.
commit ce067af118
parent 644c2696d0
Author: Francis Couture-Harpin
Date:   2024-05-02 15:00:36 -04:00

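For context (a minimal sketch, not part of the commit): the failure mode described above is that type checkers such as mypy and pyright treat a Protocol's declared-but-unassigned members as abstract, so even explicit subclasses of the protocol are flagged as uninstantiable when they don't assign those members themselves.

    from typing import Protocol

    class Model(Protocol):
        model_arch: int  # declared on the protocol, never assigned

    class LlamaModel(Model):
        ...

    LlamaModel()  # flagged by the checker: Cannot instantiate abstract
                  # class "LlamaModel" with abstract attribute "model_arch"

At runtime this instantiation succeeds (only the Protocol class itself refuses instantiation), which may be why it seemed to work; the error is purely static.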

@@ -8,10 +8,11 @@ import json
 import os
 import re
 import sys
+from abc import ABC
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Protocol, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
 
 import numpy as np
 import torch
@@ -40,7 +41,7 @@ class SentencePieceTokenTypes(IntEnum):
 AnyModel = TypeVar("AnyModel", bound="type[Model]")
 
 
-class Model(Protocol):
+class Model(ABC):
     _model_classes: dict[str, type[Model]] = {}
 
     dir_model: Path
@@ -57,6 +58,7 @@ class Model(Protocol):
     tensor_map: gguf.TensorNameMap
     tensors: dict[str, Tensor]
 
+    # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
@@ -163,12 +165,18 @@ class Model(Protocol):
         print(f"gguf: file type = {self.ftype}")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+        del name, new_name, bid, n_dims  # unused
+
         return False
 
     def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+        del name, new_name, bid, n_dims  # unused
+
         return False
 
     def write_tensors(self):
@@ -248,7 +256,7 @@ class Model(Protocol):
         return part_names
 
     @staticmethod
-    def load_hparams(dir_model):
+    def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             return json.load(f)
 
@@ -258,12 +266,14 @@ class Model(Protocol):
         def func(modelcls: AnyModel) -> AnyModel:
             for name in names:
+                if "model_arch" not in modelcls.__dict__:
+                    raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
                 cls._model_classes[name] = modelcls
             return modelcls
         return func
 
     @classmethod
-    def from_model_architecture(cls, arch):
+    def from_model_architecture(cls, arch: str) -> type[Model]:
         try:
             return cls._model_classes[arch]
         except KeyError:
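As a usage sketch of the new check (hypothetical FooModel and gguf.MODEL_ARCH.FOO, not from the commit): the guard in register() fires as soon as the decorator runs, i.e. at class-definition time, rather than much later during conversion.

    @Model.register("FooForCausalLM")
    class FooModel(Model):
        pass  # forgot: model_arch = gguf.MODEL_ARCH.FOO

    # TypeError: Missing property 'model_arch' for 'FooModel'

Looking the name up in modelcls.__dict__ is deliberate: a bare class-level annotation on Model lives in __annotations__, not __dict__, and an inherited value would live on the base class, so only an actual assignment on the registered subclass itself satisfies the check.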