convert-hf : use an ABC for Model again
It seems Protocol can't be used as a statically type-checked ABC, because its subclasses also can't be instantiated. (why did it seem to work?) At least there's still a way to throw an error when forgetting to define the `model_arch` property of any registered Model subclasses.
This commit is contained in:
parent
644c2696d0
commit
ce067af118
1 changed files with 14 additions and 4 deletions
|
@ -8,10 +8,11 @@ import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from abc import ABC
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Protocol, Sequence, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -40,7 +41,7 @@ class SentencePieceTokenTypes(IntEnum):
|
||||||
AnyModel = TypeVar("AnyModel", bound="type[Model]")
|
AnyModel = TypeVar("AnyModel", bound="type[Model]")
|
||||||
|
|
||||||
|
|
||||||
class Model(Protocol):
|
class Model(ABC):
|
||||||
_model_classes: dict[str, type[Model]] = {}
|
_model_classes: dict[str, type[Model]] = {}
|
||||||
|
|
||||||
dir_model: Path
|
dir_model: Path
|
||||||
|
@ -57,6 +58,7 @@ class Model(Protocol):
|
||||||
tensor_map: gguf.TensorNameMap
|
tensor_map: gguf.TensorNameMap
|
||||||
tensors: dict[str, Tensor]
|
tensors: dict[str, Tensor]
|
||||||
|
|
||||||
|
# subclasses should define this!
|
||||||
model_arch: gguf.MODEL_ARCH
|
model_arch: gguf.MODEL_ARCH
|
||||||
|
|
||||||
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
|
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
|
||||||
|
@ -163,12 +165,18 @@ class Model(Protocol):
|
||||||
print(f"gguf: file type = {self.ftype}")
|
print(f"gguf: file type = {self.ftype}")
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
del bid # unused
|
||||||
|
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||||
|
del name, new_name, bid, n_dims # unused
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||||
|
del name, new_name, bid, n_dims # unused
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def write_tensors(self):
|
def write_tensors(self):
|
||||||
|
@ -248,7 +256,7 @@ class Model(Protocol):
|
||||||
return part_names
|
return part_names
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_hparams(dir_model):
|
def load_hparams(dir_model: Path):
|
||||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
@ -258,12 +266,14 @@ class Model(Protocol):
|
||||||
|
|
||||||
def func(modelcls: AnyModel) -> AnyModel:
|
def func(modelcls: AnyModel) -> AnyModel:
|
||||||
for name in names:
|
for name in names:
|
||||||
|
if "model_arch" not in modelcls.__dict__:
|
||||||
|
raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
|
||||||
cls._model_classes[name] = modelcls
|
cls._model_classes[name] = modelcls
|
||||||
return modelcls
|
return modelcls
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model_architecture(cls, arch):
|
def from_model_architecture(cls, arch: str) -> type[Model]:
|
||||||
try:
|
try:
|
||||||
return cls._model_classes[arch]
|
return cls._model_classes[arch]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue