bitnet : pad tensors to 256
This commit is contained in:
parent
569a03ed97
commit
e9f2abfc8c
2 changed files with 43 additions and 17 deletions
|
@ -1423,6 +1423,20 @@ class BitnetModel(Model):
|
|||
"o_proj.weight")):
|
||||
data_torch = self.weight_quant(data_torch)
|
||||
|
||||
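# pad 1D tensors at the end so their length becomes a multiple of 256 (NOTE: `256 - n % 256` is 256 when n is already a multiple of 256, so aligned tensors still grow by a full 256; `-n % 256` would add 0)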
|
||||
# TODO: is padding with 0s an invariant, or do we also need some scaling factor?
|
||||
if name.endswith(("input_layernorm.weight", "post_attention_layernorm.weight", "model.norm.weight")):
|
||||
data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(0) % 256), mode='constant', value=0)
|
||||
logger.info(f"pad {name} to {data_torch.size()}")
|
||||
|
||||
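# pad 2D tensors along dim 1 (the last dimension) so that size(1) becomes a multiple of 256 (NOTE: `256 - n % 256` is 256 when n is already a multiple of 256, so aligned tensors still gain 256 extra columns)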
|
||||
# TODO: double-check that this is the correct way to pad the rows
|
||||
if name.endswith(("embed_tokens.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight",
|
||||
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
|
||||
"o_proj.weight")):
|
||||
data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(1) % 256), mode='constant', value=0)
|
||||
logger.info(f"pad {name} to {data_torch.size()}")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue