Put the causal_attn flag in the GGUF metadata

This commit was authored by:
Douglas Hanley, 2024-02-08 13:28:25 -06:00
parent 59c1829b0c
commit 5f1c21d0b6
2 changed files with 7 additions and 2 deletions

View file

@ -1588,6 +1588,7 @@ class BertModel(Model):
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
self.gguf_writer.add_bool("bert.attention.causal", False)
self.gguf_writer.add_file_type(self.ftype)
def set_vocab(self):

View file

@ -263,6 +263,7 @@ enum llm_kv {
LLM_KV_ATTENTION_VALUE_LENGTH,
LLM_KV_ATTENTION_LAYERNORM_EPS,
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
@ -319,6 +320,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@ -3033,8 +3035,8 @@ static void llm_load_hparams(
case LLM_ARCH_BERT:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
hparams.causal_attn = false;
switch (hparams.n_embd) {
case 384: // MiniLM
@ -7601,9 +7603,11 @@ static int llama_decode_internal(
if (!lctx.embedding.empty()) {
auto & embedding_out = lctx.embedding;
const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
embedding_out.resize(n_embd);
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), 0, n_embd*sizeof(float));
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
ggml_backend_synchronize(embeddings_backend);
}