convert-hf : small fix for mypy
This commit is contained in:
parent
16ede02a47
commit
6dba2de027
1 changed files with 7 additions and 5 deletions
|
@ -1085,12 +1085,14 @@ class LlamaModel(Model):
|
|||
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||
data_torch = data_torch.to(torch.float32)
|
||||
|
||||
if name.endswith(("q_proj.weight")):
|
||||
data_torch = permute(data_torch, n_head, n_head)
|
||||
if name.endswith(("k_proj.weight")):
|
||||
data_torch = permute(data_torch, n_head, n_kv_head)
|
||||
data = data_torch.numpy()
|
||||
|
||||
data = data_torch.squeeze().numpy()
|
||||
if name.endswith("q_proj.weight"):
|
||||
data = permute(data, n_head, n_head)
|
||||
if name.endswith("k_proj.weight"):
|
||||
data = permute(data, n_head, n_kv_head)
|
||||
|
||||
data = data.squeeze()
|
||||
|
||||
# map tensor names
|
||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue