9B - query_pre_attn_scalar = 256 not 224

See 03e657582d

Gemma 2 9B should use query_pre_attn_scalar = 256, not 224 (which is what self.config.hidden_size // self.config.num_attention_heads evaluates to).
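For concreteness, a minimal sketch of why the check removed below misfires, assuming the published google/gemma-2-9b config values (hidden_size=3584, num_attention_heads=16, head_dim=256, query_pre_attn_scalar=256); the equality it asserts happens to hold for the 27B config but not for 9B:

hparams = {
    "hidden_size": 3584,           # assumed from the google/gemma-2-9b config
    "num_attention_heads": 16,
    "head_dim": 256,
    "query_pre_attn_scalar": 256,
}

attn_scalar = hparams["query_pre_attn_scalar"]                       # 256
derived = hparams["hidden_size"] / hparams["num_attention_heads"]    # 224.0

# The removed sanity check assumed these are always equal, so 9B
# conversion raised even though its config is correct:
if attn_scalar != derived:
    print(f"sanity check would raise: {attn_scalar} != {derived}")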
Daniel Han, 2024-07-10 22:02:21 -07:00 (committed by GitHub)
parent 278d0e1846
commit 3a2e615f68

@@ -2474,11 +2474,6 @@ class Gemma2Model(Model):
         )
         self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
 
-        # sanity check
-        attn_scalar = self.hparams["query_pre_attn_scalar"]
-        if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
-            raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
-
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
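For reference, a sketch of how the scalar is consumed downstream: in the Hugging Face Gemma 2 attention implementation the softmax scaling is query_pre_attn_scalar ** -0.5 rather than the usual head_dim ** -0.5, so it is an independent hyperparameter that cannot be derived from hidden_size / num_attention_heads (values below are the assumed 9B numbers from above):

import math

query_pre_attn_scalar = 256   # assumed 9B value; 27B uses 144 with head_dim 128
scale = query_pre_attn_scalar ** -0.5

# attn_logits = (q @ k.transpose(-2, -1)) * scale
assert scale == 1 / math.sqrt(256)   # 1/16, not derived from hidden_size / n_head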