RWKV 6: Fix error in ggml_cuda_op_bin_bcast

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-12-20 15:03:34 +08:00
parent ff3d22655b
commit a20a94f566

View file

@ -3007,6 +3007,9 @@ class Rwkv6Model(Model):
if new_name.endswith("time_mix_w2.weight"):
data_torch = data_torch.permute(0, 2, 1)
if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
data_torch = data_torch.squeeze()
rescale_every_n_layers = self.hparams["rescale_every"]
if rescale_every_n_layers > 0:
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):