RWKV 6: Fix error in ggml_cuda_op_bin_bcast
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
ff3d22655b
commit
a20a94f566
1 changed files with 3 additions and 0 deletions
|
@ -3007,6 +3007,9 @@ class Rwkv6Model(Model):
|
|||
if new_name.endswith("time_mix_w2.weight"):
|
||||
data_torch = data_torch.permute(0, 2, 1)
|
||||
|
||||
if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
|
||||
data_torch = data_torch.squeeze()
|
||||
|
||||
rescale_every_n_layers = self.hparams["rescale_every"]
|
||||
if rescale_every_n_layers > 0:
|
||||
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue