convert-hf : remove einops requirement for InternLM2
This commit is contained in:
parent
0c3833286e
commit
98db4347e8
4 changed files with 22 additions and 20 deletions
|
@ -1890,16 +1890,18 @@ in chat mode so that the conversation can end normally.")
|
|||
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
|
||||
|
||||
if re.match(qkv_pattern, name):
|
||||
from einops import rearrange
|
||||
|
||||
bid = re.findall(qkv_pattern, name)[0]
|
||||
qkv = data_torch
|
||||
qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
|
||||
# qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
|
||||
qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
|
||||
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
|
||||
# The model weights of q and k equire additional reshape.
|
||||
q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
|
||||
k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
|
||||
v = rearrange(v, " o g n i -> o (g n i)").T
|
||||
# q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
|
||||
q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
|
||||
# k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
|
||||
k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
|
||||
# v = rearrange(v, " o g n i -> o (g n i)").T
|
||||
v = v.reshape((v.shape[0], -1)).T
|
||||
return [
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
|
||||
|
@ -2238,13 +2240,13 @@ class OlmoModel(Model):
|
|||
class LazyTorchTensor:
|
||||
_meta: Tensor
|
||||
_data: Tensor | None
|
||||
_args: list[Any]
|
||||
_func: Callable[[list[Any]], Tensor] | None = None
|
||||
_args: tuple
|
||||
_func: Callable[[tuple], Tensor] | None
|
||||
|
||||
def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
|
||||
def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
|
||||
self._meta = meta
|
||||
self._data = data
|
||||
self._args = args if args is not None else []
|
||||
self._args = args
|
||||
self._func = func
|
||||
|
||||
@staticmethod
|
||||
|
@ -2266,19 +2268,22 @@ class LazyTorchTensor:
|
|||
def wrapped_fn(*args, **kwargs):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
args_list = ([self] if use_self else []) + list(args)
|
||||
args = ((self,) if use_self else ()) + args
|
||||
|
||||
meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
|
||||
meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
|
||||
|
||||
return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
|
||||
return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
|
||||
return wrapped_fn
|
||||
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
meta_attr = getattr(self._meta, __name)
|
||||
if not callable(meta_attr):
|
||||
return meta_attr
|
||||
else:
|
||||
if callable(meta_attr):
|
||||
return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
|
||||
elif isinstance(meta_attr, torch.Tensor):
|
||||
# for things like self.T
|
||||
return self._wrap_fn(lambda s: getattr(s, __name))(self)
|
||||
else:
|
||||
return meta_attr
|
||||
|
||||
_dtype_map: dict[torch.dtype, type] = {
|
||||
torch.float16: np.float16,
|
||||
|
@ -2295,7 +2300,7 @@ class LazyTorchTensor:
|
|||
|
||||
@overload
|
||||
@staticmethod
|
||||
def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
|
||||
def to_eager(t: tuple) -> tuple: ...
|
||||
|
||||
@staticmethod
|
||||
def to_eager(t: Any) -> Any:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue