Merge branch 'master' into compilade/faster-lazy-safetensors

This commit is contained in:
Francis Couture-Harpin 2024-07-15 15:24:25 -04:00
commit 2a49a68d70
25 changed files with 1531 additions and 720 deletions

View file

@ -2271,13 +2271,6 @@ class InternLM2Model(Model):
special_vocab.add_to_gguf(self.gguf_writer)
def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
def set_gguf_parameters(self):
self.gguf_writer.add_name("InternLM2")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@ -2297,26 +2290,22 @@ class InternLM2Model(Model):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_heads = self.hparams["num_attention_heads"]
num_kv_heads = self.hparams["num_key_value_heads"]
hidden_size = self.hparams["hidden_size"]
n_embd = self.hparams["hidden_size"]
q_per_kv = num_heads // num_kv_heads
head_dim = hidden_size // num_heads
head_dim = n_embd // num_heads
num_groups = num_heads // q_per_kv
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
if re.match(qkv_pattern, name):
bid = re.findall(qkv_pattern, name)[0]
if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
qkv = data_torch
# qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
# The model weights of q and k equire additional reshape.
# q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
# k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
# v = rearrange(v, " o g n i -> o (g n i)").T
v = v.reshape((v.shape[0], -1)).T
q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
v = v.reshape((-1, v.shape[-1]))
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
@ -3620,6 +3609,7 @@ def main() -> None:
small_first_shard=args.no_tensor_first_split)
logger.info("Set model parameters")
model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
model_instance.set_gguf_parameters()
logger.info("Set model tokenizer")