convert-hf : small fix for mypy

This commit is contained in:
Jared Van Bortel 2024-03-28 18:01:05 -04:00
parent 16ede02a47
commit 6dba2de027

View file

@ -1085,12 +1085,14 @@ class LlamaModel(Model):
if data_torch.dtype not in (torch.float16, torch.float32): if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32) data_torch = data_torch.to(torch.float32)
if name.endswith(("q_proj.weight")): data = data_torch.numpy()
data_torch = permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight")):
data_torch = permute(data_torch, n_head, n_kv_head)
data = data_torch.squeeze().numpy() if name.endswith("q_proj.weight"):
data = permute(data, n_head, n_head)
if name.endswith("k_proj.weight"):
data = permute(data, n_head, n_kv_head)
data = data.squeeze()
# map tensor names # map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))