clean up previous hack

This commit is contained in:
Green Sky 2023-06-21 19:52:40 +02:00
parent 72397fbe63
commit 0141e6395c
No known key found for this signature in database

View file

@ -649,18 +649,11 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
out["norm.weight"] = model["model.norm.weight"]
out["output.weight"] = model["lm_head.weight"]
n_embd = out["tok_embeddings.weight"].shape[1]
n_head = n_embd // 128 # guessed
if "model.layers.0.self_attn.rotary_emb.inv_freq" in model:
dim_inv_freq = model["model.layers.0.self_attn.rotary_emb.inv_freq"].shape[0]
n_head = n_embd // (dim_inv_freq * 2)
for i in itertools.count():
if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
break
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], n_head)
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], n_head)
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]