This commit is contained in:
namtranase 2023-12-25 17:13:50 +07:00
commit 44f4ce2272
3 changed files with 3 additions and 6 deletions

View file

@ -189,10 +189,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
elif (
isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm))
or "rmsnorm" in str(prev_op.__class__).lower()
):
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower():
scale_ln_fcs(prev_op, layers, scales)
elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
new_module = ScaledActivation(prev_op, scales)