Add causal attention GGUF key

This commit is contained in:
Douglas Hanley 2024-02-08 16:41:05 -06:00
parent e3efcf13c8
commit 96d37f8d55
3 changed files with 6 additions and 2 deletions

View file

@ -1647,7 +1647,7 @@ class BertModel(Model):
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
self.gguf_writer.add_bool("bert.attention.causal", False)
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_file_type(self.ftype)
def set_vocab(self):
@ -1662,7 +1662,7 @@ class BertModel(Model):
# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
n_token_types = len(set(toktypes))
self.gguf_writer.add_uint32("tokenizer.ggml.token_type_count", n_token_types)
self.gguf_writer.add_token_type_count(n_token_types)
# convert to phantom space vocab
def phantom(tok, typ):

View file

@ -50,6 +50,7 @@ class Keys:
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
CAUSAL = "{arch}.attention.causal"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"

View file

@ -357,6 +357,9 @@ class GGUFWriter:
def add_layer_norm_rms_eps(self, value: float) -> None:
    """Write the architecture-specific RMS layer-norm epsilon to the GGUF metadata."""
    key = Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch)
    self.add_float32(key, value)
def add_causal_attention(self, value: bool) -> None:
    """Write the architecture-specific causal-attention flag to the GGUF metadata."""
    key = Keys.Attention.CAUSAL.format(arch=self.arch)
    self.add_bool(key, value)
def add_rope_dimension_count(self, count: int) -> None:
    """Write the architecture-specific RoPE dimension count to the GGUF metadata."""
    key = Keys.Rope.DIMENSION_COUNT.format(arch=self.arch)
    self.add_uint32(key, count)