bitnet : pad tensors to 256

Author: Georgi Gerganov
Date:   2024-06-15 19:01:03 +03:00
Parent: 569a03ed97
Commit: e9f2abfc8c
2 changed files with 43 additions and 17 deletions

@@ -1423,6 +1423,20 @@ class BitnetModel(Model):
                           "o_proj.weight")):
             data_torch = self.weight_quant(data_torch)
 
+        # pad 1D tensors
+        # TODO: is padding with 0s an invariant, or do we also need some scaling factor?
+        if name.endswith(("input_layernorm.weight", "post_attention_layernorm.weight", "model.norm.weight")):
+            data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(0) % 256), mode='constant', value=0)
+            logger.info(f"pad {name} to {data_torch.size()}")
+
+        # pad 2D tensors
+        # TODO: double-check that this is the correct way to pad the rows
+        if name.endswith(("embed_tokens.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight",
+                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
+                          "o_proj.weight")):
+            data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(1) % 256), mode='constant', value=0)
+            logger.info(f"pad {name} to {data_torch.size()}")
+
         return [(self.map_tensor_name(name), data_torch)]
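
For reference, a minimal standalone sketch (not part of the commit; shapes and names below are illustrative only) of the padding arithmetic used in the hunk: torch.nn.functional.pad with a pad tuple of (0, k) appends k zeros to the last dimension, and 256 - size % 256 is the number of zeros needed to reach the next multiple of 256 (note that when the size is already a multiple of 256 this expression pads a full extra block of 256, which is what the first TODO hints at).

    import torch
    import torch.nn.functional as F

    def pad_last_dim_to_256(t: torch.Tensor) -> torch.Tensor:
        # Same expression as in the diff: zeros appended to the last
        # dimension until its size reaches the next multiple of 256.
        n_pad = 256 - t.size(-1) % 256
        return F.pad(t, (0, n_pad), mode='constant', value=0)

    # Illustrative shapes only (not taken from the commit).
    norm_w = torch.ones(3200)         # 1D norm-style weight
    proj_w = torch.ones(3200, 8640)   # 2D projection-style weight, rows padded

    print(pad_last_dim_to_256(norm_w).shape)   # torch.Size([3328])
    print(pad_last_dim_to_256(proj_w).shape)   # torch.Size([3200, 8704])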