gemma2: add sliding window mask (#8227)

* gemma2: add sliding window mask

* fix data_swa uninitialized

* better naming

* add co-author

Co-authored-by: Arlo Phoenix <arlo-phoenix@users.noreply.github.com>

* replace list with single tensor

* update

* llama : minor styling

* convert : add sanity check for query_pre_attn_scalar

* fix small typo in README

---------

Co-authored-by: Arlo Phoenix <arlo-phoenix@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Xuan Son Nguyen 2024-07-01 18:48:34 +02:00 committed by GitHub
parent 0ddeff1023
commit 49122a873f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 79 additions and 32 deletions

View file

@ -66,6 +66,7 @@ class Keys:
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"

View file

@ -552,6 +552,9 @@ class GGUFWriter:
def add_relative_attn_buckets_count(self, value: int) -> None:
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)