gemma2: add sliding window mask (#8227)
* gemma2: add sliding window mask * fix data_swa uninitialized * better naming * add co-author Co-authored-by: Arlo Phoenix <arlo-phoenix@users.noreply.github.com> * replace list with single tensor * update * llama : minor styling * convert : add sanity check for query_pre_attn_scalar * fix small typo in README --------- Co-authored-by: Arlo Phoenix <arlo-phoenix@users.noreply.github.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
0ddeff1023
commit
49122a873f
5 changed files with 79 additions and 32 deletions
|
@ -2369,6 +2369,12 @@ class Gemma2Model(Model):
|
|||
self.gguf_writer.add_final_logit_softcapping(
|
||||
self.hparams["final_logit_softcapping"]
|
||||
)
|
||||
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
|
||||
|
||||
# sanity check
|
||||
attn_scalar = self.hparams["query_pre_attn_scalar"]
|
||||
if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
|
||||
raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unusem
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue