put causal_attn flag in gguf

parent 59c1829b0c
commit 5f1c21d0b6

2 changed files with 7 additions and 2 deletions
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1588,6 +1588,7 @@ class BertModel(Model):
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+        self.gguf_writer.add_bool("bert.attention.causal", False)
         self.gguf_writer.add_file_type(self.ftype)
 
     def set_vocab(self):
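The converter now writes the flag under the literal key bert.attention.causal (always False for BERT, whose encoder attention is bidirectional). A minimal read-back sketch follows; it is not part of this commit and assumes ggml's gguf C API (gguf_init_from_file, gguf_find_key, gguf_get_val_bool) is available from ggml.h in this revision of the tree:

// read the bert.attention.causal flag back from a converted model file
#include "ggml.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    // metadata only, do not allocate tensor data
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) {
        fprintf(stderr, "failed to load %s\n", argv[1]);
        return 1;
    }

    const int kid = gguf_find_key(ctx, "bert.attention.causal");
    if (kid >= 0) {
        printf("bert.attention.causal = %s\n", gguf_get_val_bool(ctx, kid) ? "true" : "false");
    } else {
        printf("bert.attention.causal not present (older conversion)\n");
    }

    gguf_free(ctx);
    return 0;
}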
--- a/llama.cpp
+++ b/llama.cpp
@@ -263,6 +263,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_VALUE_LENGTH,
     LLM_KV_ATTENTION_LAYERNORM_EPS,
     LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
+    LLM_KV_ATTENTION_CAUSAL,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -319,6 +320,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_VALUE_LENGTH,       "%s.attention.value_length"           },
     { LLM_KV_ATTENTION_LAYERNORM_EPS,      "%s.attention.layer_norm_epsilon"     },
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,  "%s.attention.layer_norm_rms_epsilon" },
+    { LLM_KV_ATTENTION_CAUSAL,             "%s.attention.causal"                 },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,  "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE,        "%s.rope.freq_base"       },
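Each entry in LLM_KV_NAMES is a printf-style template whose %s is filled in with the architecture name, so LLM_KV_ATTENTION_CAUSAL resolves to the same bert.attention.causal key the converter writes. A small illustrative sketch of that expansion (the kv_name helper is hypothetical; llama.cpp uses its own format() wrapper for this):

#include <cstdio>
#include <string>

// hypothetical helper: expand an LLM_KV_NAMES template with the architecture name
static std::string kv_name(const char * tmpl, const char * arch) {
    char buf[256];
    snprintf(buf, sizeof(buf), tmpl, arch);
    return std::string(buf);
}

int main() {
    // "%s.attention.causal" + "bert" -> "bert.attention.causal"
    printf("%s\n", kv_name("%s.attention.causal", "bert").c_str());
    return 0;
}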
@@ -3033,8 +3035,8 @@ static void llm_load_hparams(
         case LLM_ARCH_BERT:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                hparams.causal_attn = false;
 
                 switch (hparams.n_embd) {
                     case 384: // MiniLM
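With this change hparams.causal_attn is read from the model file instead of being hardcoded to false for BERT. The flag decides whether a token may attend to later positions. The following standalone sketch is illustrative only (not llama.cpp code): it builds the additive attention mask for both settings, with -INFINITY marking positions that are blocked before the softmax:

#include <cmath>
#include <cstdio>
#include <vector>

// Build an n_tokens x n_tokens additive mask: 0.0f where attention is allowed,
// -INFINITY where it is blocked (added to the KQ scores before softmax).
static std::vector<float> build_mask(int n_tokens, bool causal) {
    std::vector<float> mask(n_tokens * n_tokens, 0.0f);
    if (causal) {
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = i + 1; j < n_tokens; ++j) {
                mask[i*n_tokens + j] = -INFINITY;  // token i cannot see future token j
            }
        }
    }
    return mask;
}

int main() {
    const int n_tokens = 4;
    for (bool causal : {true, false}) {
        printf("causal_attn = %s\n", causal ? "true" : "false");
        const std::vector<float> mask = build_mask(n_tokens, causal);
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j < n_tokens; ++j) {
                printf(" %5s", std::isinf(mask[i*n_tokens + j]) ? "-inf" : "0");
            }
            printf("\n");
        }
    }
    return 0;
}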
@@ -7601,9 +7603,11 @@ static int llama_decode_internal(
     if (!lctx.embedding.empty()) {
         auto & embedding_out = lctx.embedding;
 
+        const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
+
         embedding_out.resize(n_embd);
 
         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), 0, n_embd*sizeof(float));
+        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
         ggml_backend_synchronize(embeddings_backend);
     }
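The second change copies the embedding of the last token rather than the first: treating the graph output as a flat row-major [n_tokens, n_embd] float buffer, the row of token t starts at element n_embd*t, so embed_pos = n_embd*(n_tokens-1) addresses the final row (falling back to 0 when res is not set). A small sketch of that arithmetic on a host-side buffer (assumed layout, not llama.cpp code):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 3;
    const int n_embd   = 4;

    // pretend graph output: the embedding of token t is [t, t, t, t]
    std::vector<float> output(n_tokens * n_embd);
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < n_embd; ++i) {
            output[t*n_embd + i] = (float) t;
        }
    }

    // before this commit: offset 0                     -> first token's embedding
    // after this commit:  offset n_embd*(n_tokens - 1) -> last token's embedding
    const int embed_pos = n_embd * (n_tokens - 1);

    std::vector<float> embedding_out(n_embd);
    std::copy(output.begin() + embed_pos, output.begin() + embed_pos + n_embd, embedding_out.begin());

    for (float v : embedding_out) {
        printf("%.1f ", v);   // prints: 2.0 2.0 2.0 2.0
    }
    printf("\n");
    return 0;
}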