convert-hf : remove einops requirement for InternLM2
commit 98db4347e8
parent 0c3833286e

4 changed files with 22 additions and 20 deletions
@@ -86,7 +86,6 @@ let
   # TODO(Green-Sky): find a better way to opt-into the heavy ml python runtime
   llama-python-extra = python3.withPackages (
     ps: [
-      ps.einops
       ps.numpy
       ps.sentencepiece
       ps.tiktoken
@@ -1890,16 +1890,18 @@ in chat mode so that the conversation can end normally.")
         qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"

         if re.match(qkv_pattern, name):
-            from einops import rearrange
-
             bid = re.findall(qkv_pattern, name)[0]
             qkv = data_torch
-            qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+            # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+            qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
             q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
             # The model weights of q and k equire additional reshape.
-            q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
-            k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
-            v = rearrange(v, " o g n i -> o (g n i)").T
+            # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
+            q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
+            # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
+            k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
+            # v = rearrange(v, " o g n i -> o (g n i)").T
+            v = v.reshape((v.shape[0], -1)).T
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
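The replacement works because a single-axis split or merge in einops is just a row-major reshape. Below is a quick standalone check of that equivalence with toy dimensions (the sizes are illustrative, not InternLM2's real ones; einops is imported only to compare against the old expressions and is skipped if absent):

# Sketch: verify that the reshape-based split/merge matches the old einops
# expressions for a toy wqkv weight. Toy sizes; not part of the converter.
import torch

num_groups, q_per_kv, head_dim, hidden = 2, 4, 8, 16
w = torch.arange(num_groups * (q_per_kv + 2) * head_dim * hidden, dtype=torch.float32)
w = w.reshape(num_groups * (q_per_kv + 2) * head_dim, hidden)

qkv = w.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))   # new code path
q = qkv[..., :q_per_kv, :]
q_flat = q.reshape((q.shape[0], -1)).T                        # replaces rearrange(...).T

try:
    from einops import rearrange                              # optional cross-check only
    qkv_ref = rearrange(w.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
    q_ref = rearrange(qkv_ref[..., :q_per_kv, :], "o g n i -> o (g n i)").T
    assert torch.equal(qkv, qkv_ref) and torch.equal(q_flat, q_ref)
except ImportError:
    pass  # einops is no longer required; the reshape path stands on its own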
@@ -2238,13 +2240,13 @@ class OlmoModel(Model):
 class LazyTorchTensor:
     _meta: Tensor
     _data: Tensor | None
-    _args: list[Any]
-    _func: Callable[[list[Any]], Tensor] | None = None
+    _args: tuple
+    _func: Callable[[tuple], Tensor] | None

-    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
+    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
         self._meta = meta
         self._data = data
-        self._args = args if args is not None else []
+        self._args = args
         self._func = func

     @staticmethod
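The fields above amount to a small deferred-computation record: _meta carries shape and dtype on the meta device, _args holds the inputs, and _func knows how to produce the real tensor when it is finally needed. A minimal standalone sketch of the same idea, with illustrative names that are not the ones in this file:

# Illustrative sketch of the deferred-tensor pattern (hypothetical names,
# not the LazyTorchTensor class from this commit).
from typing import Callable
import torch

class DeferredTensor:
    def __init__(self, *, meta: torch.Tensor, args: tuple = (), func: Callable[[tuple], torch.Tensor]):
        self._meta = meta    # shape/dtype only, allocated on the zero-byte "meta" device
        self._args = args    # inputs of the pending computation, stored as a tuple
        self._func = func    # turns the materialized args into the real tensor

    def materialize(self) -> torch.Tensor:
        # resolve nested deferred inputs first, then run the stored function
        real = tuple(a.materialize() if isinstance(a, DeferredTensor) else a for a in self._args)
        return self._func(real)

# usage: the shape is known up front, the transpose is only computed on demand
src = torch.ones(4, 2)
lazy = DeferredTensor(meta=torch.empty(2, 4, device="meta"), args=(src,), func=lambda a: a[0].T)
print(lazy._meta.shape)           # torch.Size([2, 4]), no real data touched yet
print(lazy.materialize().shape)   # torch.Size([2, 4]), actual tensor computed here

Storing args as a tuple rather than a list mirrors the change in this hunk: the inputs are never mutated after construction, so an immutable tuple is the more natural container.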
@@ -2266,19 +2268,22 @@ class LazyTorchTensor:
         def wrapped_fn(*args, **kwargs):
             if kwargs is None:
                 kwargs = {}
-            args_list = ([self] if use_self else []) + list(args)
+            args = ((self,) if use_self else ()) + args

-            meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
+            meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)

-            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
+            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
         return wrapped_fn

     def __getattr__(self, __name: str) -> Any:
         meta_attr = getattr(self._meta, __name)
-        if not callable(meta_attr):
-            return meta_attr
-        else:
+        if callable(meta_attr):
             return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
+        elif isinstance(meta_attr, torch.Tensor):
+            # for things like self.T
+            return self._wrap_fn(lambda s: getattr(s, __name))(self)
+        else:
+            return meta_attr

     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
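A quick illustration (outside the converter) of why the extra elif branch is needed: on a meta tensor, a method like reshape is callable, a property like T is itself a Tensor, and something like dtype is neither, so the three branches above cover all three cases.

# Not part of the script: shows how attributes of a meta tensor fall into the
# three cases that the __getattr__ above dispatches on.
import torch

m = torch.empty(2, 3, device="meta")
print(callable(m.reshape))             # True  -> wrapped as a lazy method call
print(isinstance(m.T, torch.Tensor))   # True  -> property returning a tensor, also wrapped
print(callable(m.dtype) or isinstance(m.dtype, torch.Tensor))  # False -> returned as-is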
@@ -2295,7 +2300,7 @@ class LazyTorchTensor:

     @overload
     @staticmethod
-    def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
+    def to_eager(t: tuple) -> tuple: ...

     @staticmethod
     def to_eager(t: Any) -> Any:
@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0
@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0