model: dbrx convert permute experts directly torch, log shape

Author: Pierrick HYMBERT
Date: 2024-04-08 19:01:44 +02:00
Parent: f20c04f01f
Commit: 48909ed2a7

```diff
@@ -1531,15 +1531,11 @@ class DbrxModel(Model):
             for exp_tensor_name in exp_tensor_names.keys():
                 if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
                     experts = True
-                    expert_permute = exp_tensor_names[exp_tensor_name]
+                    data_torch = data_torch.view(n_expert, n_ff, n_embd).permute(*exp_tensor_names[exp_tensor_name])
                     break

             old_dtype = data_torch.dtype

-            # View experts tensors as 3D
-            if experts:
-                data_torch = data_torch.view(n_expert, n_ff, n_embd)
-
             # convert any unsupported data types to float32
             if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
```
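The change folds the later numpy transpose into a single torch `view(...).permute(...)` applied as soon as an experts tensor is matched. A minimal sketch of the equivalence, using hypothetical small dimensions and an example permutation (the real values come from the model hparams and the `exp_tensor_names` mapping):

```python
import torch

# Hypothetical small dimensions and permutation, for illustration only;
# the real values come from the DBRX hparams and exp_tensor_names.
n_expert, n_ff, n_embd = 4, 8, 6
perm = (0, 2, 1)

flat = torch.arange(n_expert * n_ff * n_embd, dtype=torch.float32)

# New approach: view as 3D and permute directly in torch.
a = flat.view(n_expert, n_ff, n_embd).permute(*perm)

# Old approach: view as 3D in torch, transpose later as a numpy array.
b = flat.view(n_expert, n_ff, n_embd).numpy().transpose(perm)

assert (a.numpy() == b).all()
print(a.shape)  # torch.Size([4, 6, 8])
```

Because the permute now happens before the dtype handling, everything downstream already sees the final 3D shape, which is what makes the separate numpy transpose step removed in the next hunk unnecessary.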
```diff
@@ -1560,11 +1556,6 @@ class DbrxModel(Model):
             n_dims = len(data.shape)
             data_dtype = data.dtype

-            # Transpose experts to the expected llama.cpp format
-            if experts:
-                data = data.transpose(expert_permute)
-                n_dims = len(data.shape)
-
             # if f32 desired, convert any float16 to float32
             if self.ftype == 0 and data_dtype == np.float16:
                 data = data.astype(np.float32)
```
```diff
@@ -1573,7 +1564,7 @@ class DbrxModel(Model):

             if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
                 data = data.astype(np.float16)

-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

             self.gguf_writer.add_tensor(new_name, data)
```
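For reference, a standalone rendering of the updated log line; the tensor name, dtype, and shape below are hypothetical stand-ins (real DBRX experts tensors are far larger):

```python
import numpy as np
import torch

# Hypothetical tensor name, source dtype, and shape, for illustration only.
new_name = "blk.0.ffn_up_exps.weight"
old_dtype = torch.float32
data = np.zeros((4, 6, 8), dtype=np.float16)
n_dims = len(data.shape)

print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
# blk.0.ffn_up_exps.weight, n_dims = 3, shape = (4, 6, 8), torch.float32 --> float16
```

Logging the shape alongside the dtype conversion makes it easy to verify that each experts tensor ends up 3D in the layout llama.cpp expects.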