convert : add sanity check for query_pre_attn_scalar

This commit is contained in:
Georgi Gerganov 2024-07-01 18:38:24 +03:00
parent ce711f6eae
commit 7dc9cbf03f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -2371,6 +2371,11 @@ class Gemma2Model(Model):
)
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
# sanity check
attn_scalar = self.hparams["query_pre_attn_scalar"]
if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unusem