From 60e36747caf4dbdec6a9b35e6f2878523c37bbea Mon Sep 17 00:00:00 2001
From: Oliver Ye
Date: Mon, 22 Jul 2024 17:26:11 -0700
Subject: [PATCH] XLMRoberta support

---
 convert_hf_to_gguf.py        | 65 ++++++++++++++++++++++++++++++++
 convert_hf_to_gguf_update.py |  3 +-
 gguf-py/gguf/constants.py    | 17 +++++++++
 include/llama.h              |  1 +
 src/llama.cpp                | 73 ++++++++++++++++++++++++++++--------
 5 files changed, 142 insertions(+), 17 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index dde4fa9c8..8fe9286cd 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -585,6 +585,9 @@ class Model:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "a81863d07e75497e2194eb1a1574d5e5cd4d5f85a87a0728b922bf2bed6fb327":
+            # ref: https://huggingface.co/intfloat/multilingual-e5-base
+            res = "multilingual-e5-base"
         if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
             # ref: https://huggingface.co/THUDM/glm-4-9b-chat
             res = "chatglm-bpe"
@@ -2447,6 +2450,68 @@ class BertModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("XLMRobertaModel")
+class XLMRobertaModel(Model):
+    model_arch = gguf.MODEL_ARCH.XLMROBERTA
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.vocab_size = None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_causal_attention(False)
+
+        # get pooling path
+        pooling_path = None
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
+
+        # get pooling type
+        if pooling_path is not None:
+            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+                pooling = json.load(f)
+            if pooling["pooling_mode_mean_tokens"]:
+                pooling_type = gguf.PoolingType.MEAN
+            elif pooling["pooling_mode_cls_token"]:
+                pooling_type = gguf.PoolingType.CLS
+            else:
+                raise NotImplementedError("Only MEAN and CLS pooling types supported")
+            self.gguf_writer.add_pooling_type(pooling_type)
+
+    def set_vocab(self):
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.vocab_size = len(tokens)
+
+        self.gguf_writer.add_token_type_count(int(self.hparams['type_vocab_size']))
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_eos_token(True)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # we only use the encoder for embeddings, so the pooling head is not needed
+        if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+            return []  # we don't need these
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("NomicBertModel")
 class NomicBertModel(BertModel):
     model_arch = gguf.MODEL_ARCH.NOMIC_BERT
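A note on the pooling probe in XLMRobertaModel.set_gguf_parameters: sentence-transformers checkpoints such as intfloat/multilingual-e5-base ship a modules.json listing the pipeline modules, and the Pooling module's subdirectory (typically 1_Pooling) holds a config.json with pooling_mode_* flags; the E5 models use mean pooling. A minimal standalone sketch of the same probe, for experimenting outside the converter (the function name and string return values are illustrative only):

    import json
    from pathlib import Path

    def detect_pooling(model_dir: str) -> str | None:
        # locate the sentence-transformers Pooling module and read its mode flags
        root = Path(model_dir)
        modules_file = root / "modules.json"
        if not modules_file.is_file():
            return None  # plain transformers checkpoint: no pooling spec
        for mod in json.loads(modules_file.read_text(encoding="utf-8")):
            if mod["type"] == "sentence_transformers.models.Pooling":
                cfg_file = root / mod["path"] / "config.json"
                cfg = json.loads(cfg_file.read_text(encoding="utf-8"))
                if cfg["pooling_mode_mean_tokens"]:
                    return "MEAN"
                if cfg["pooling_mode_cls_token"]:
                    return "CLS"
        return None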
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index d5a2d925e..3727c3435 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -86,6 +86,7 @@ models = [
     {"name": "smaug-bpe",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
     {"name": "poro-chat",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
+    {"name": "multilingual-e5-base", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/intfloat/multilingual-e5-base", },
     {"name": "viking",               "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
     {"name": "gemma",                "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2",              "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
@@ -144,7 +145,7 @@ for model in models:
     name = model["name"]
     tokt = model["tokt"]
 
-    if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
+    if (tokt == TOKENIZER_TYPE.SPM and name != "multilingual-e5-base") or tokt == TOKENIZER_TYPE.UGM:
         continue
 
     # Skip if the tokenizer folder does not exist or there are other download issues previously
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index e343c2ef1..8fd508707 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -217,6 +217,7 @@ class MODEL_ARCH(IntEnum):
     BITNET     = auto()
     T5         = auto()
     JAIS       = auto()
+    XLMROBERTA = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -344,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BITNET:     "bitnet",
     MODEL_ARCH.T5:         "t5",
     MODEL_ARCH.JAIS:       "jais",
+    MODEL_ARCH.XLMROBERTA: "xlm-roberta",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -538,6 +540,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.XLMROBERTA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.NOMIC_BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
diff --git a/include/llama.h b/include/llama.h
index bf2761467..0a4c3de2e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -95,6 +95,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_TEKKEN    = 20,
         LLAMA_VOCAB_PRE_TYPE_SMOLLM    = 21,
         LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
+        LLAMA_VOCAB_PRE_TYPE_E5        = 23,
     };
 
     // note: these values should be synchronized with ggml_rope
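With the new architecture string and vocab pre-type registered above, a converted file can be spot-checked with gguf-py before loading it in llama.cpp. A sketch, assuming a converted file named e5-base.gguf; the string-field access pattern follows gguf-py's ReaderField layout, where data indexes the payload part:

    from gguf import GGUFReader  # pip install gguf

    def read_str(reader: GGUFReader, key: str) -> str:
        # a GGUF string field keeps its payload bytes in parts[data[0]]
        field = reader.fields[key]
        return bytes(field.parts[field.data[0]]).decode("utf-8")

    reader = GGUFReader("e5-base.gguf")
    print(read_str(reader, "general.architecture"))  # expected: xlm-roberta
    print(read_str(reader, "tokenizer.ggml.model"))  # expected: llama
    print(read_str(reader, "tokenizer.ggml.pre"))    # expected: multilingual-e5-base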
diff --git a/src/llama.cpp b/src/llama.cpp
index 99a6d8b66..68f8d0e2d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -241,6 +241,7 @@ enum llm_arch {
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_JAIS,
+    LLM_ARCH_XLMROBERTA,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -285,6 +286,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BITNET,     "bitnet"      },
     { LLM_ARCH_T5,         "t5"          },
     { LLM_ARCH_JAIS,       "jais"        },
+    { LLM_ARCH_XLMROBERTA, "xlm-roberta" },
     { LLM_ARCH_UNKNOWN,    "(unknown)"   },
 };
 
@@ -765,6 +767,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_XLMROBERTA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_POS_EMBD,        "position_embd" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_NOMIC_BERT,
         {
@@ -4846,6 +4865,7 @@ static void llm_load_hparams(
                 hparams.f_max_alibi_bias = 8.0f;
             } break;
         case LLM_ARCH_BERT:
+        case LLM_ARCH_XLMROBERTA:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
@@ -5532,7 +5552,12 @@ static void llm_load_vocab(
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            if (tokenizer_pre == "multilingual-e5-base") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_E5;
+            }
+            else {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            }
             vocab.tokenizer_add_space_prefix = true;
             vocab.tokenizer_clean_spaces = false;
             vocab.tokenizer_add_bos = true;
@@ -6398,12 +6423,12 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_BERT:
             case LLM_ARCH_NOMIC_BERT:
+            case LLM_ARCH_XLMROBERTA:
                 {
                     model.tok_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab});
                     model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
-
-                    if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train});
+                    if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
+                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train});
                     }
 
                     model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
@@ -6415,7 +6440,7 @@ static bool llm_load_tensors(
 
                         auto & layer = model.layers[i];
 
-                        if (model.arch == LLM_ARCH_BERT) {
+                        if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
                             layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
                             layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
 
@@ -6436,10 +6461,10 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
 
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+                        if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
+                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
                         } else {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                         }
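The LLAMA_VOCAB_PRE_TYPE_E5 selected in llm_load_vocab above changes tokenization in exactly one place, in llama_tokenize_internal further down: runs of whitespace are collapsed to a single space before SentencePiece escaping, as a stand-in for the model's precompiled SentencePiece normalizer. The equivalent transformation in Python (the function name is illustrative):

    import re

    def e5_whitespace_normalize(text: str) -> str:
        # mirrors std::regex_replace(raw_text, std::regex("\\s+"), " ")
        return re.sub(r"\s+", " ", text)

    assert e5_whitespace_normalize("query:  hello\n\tworld") == "query: hello world"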
@@ -9769,7 +9794,7 @@ struct llm_build_context {
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
         inpL = ggml_add(ctx0, inpL, type_row0);
-        if (model.arch == LLM_ARCH_BERT) {
+        if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
             inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
         }
         cb(inpL, "inp_embd", -1);
@@ -9790,8 +9815,8 @@ struct llm_build_context {
             struct ggml_tensor * Vcur;
 
             // self-attention
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2 || model.arch == LLM_ARCH_XLMROBERTA) {
+                Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
                 if (model.layers[il].attn_q_norm) {
@@ -9898,7 +9923,7 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            if (model.arch == LLM_ARCH_BERT) {
+            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b, NULL,
                         NULL,                      NULL,                      NULL,
@@ -13856,6 +13881,7 @@ static struct ggml_cgraph * llama_build_graph(
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_XLMROBERTA:
             {
                 result = llm.build_bert();
             } break;
@@ -14065,8 +14091,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
     if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        if (lctx.model.arch != LLM_ARCH_XLMROBERTA) {
+            ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        } else {
+            std::vector<int32_t> position_ids(n_tokens, 0);
+            for (int64_t i = 0; i < n_tokens; ++i) {
+                if (i == 0 || batch.pos[i] != 0) {
+                    position_ids[i] = batch.pos[i] + lctx.model.vocab.special_pad_id + 1;
+                }
+            }
+            ggml_backend_tensor_set(lctx.inp_pos, position_ids.data(), 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        }
     }
 
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
@@ -15306,7 +15341,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
             }
             // Try to fall back to just the byte as a string
             const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
+            return vocab.token_to_id.count(buf2) ? vocab.token_to_id.at(buf2) : vocab.special_unk_id;
         }
         case LLAMA_VOCAB_TYPE_WPM:
         case LLAMA_VOCAB_TYPE_BPE: {
@@ -16471,6 +16506,11 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
                         llm_tokenizer_spm tokenizer(vocab);
+                        // temporary workaround for the SPM preprocessor: collapse whitespace runs
+                        if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_E5) {
+                            std::regex ws_re("\\s+");
+                            raw_text = std::regex_replace(raw_text, ws_re, " ");
+                        }
                         llama_escape_whitespace(raw_text);
                         tokenizer.tokenize(raw_text, output);
                         is_prev_special = false;
@@ -19480,6 +19520,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_XLMROBERTA:
         case LLM_ARCH_CODESHELL:
             return LLAMA_ROPE_TYPE_NEOX;
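On the position-id remapping in llama_set_inputs above: XLM-RoBERTa inherits fairseq's learned-position convention, in which embedding rows 0 through padding_idx are reserved and real tokens are numbered from padding_idx + 1 (HF transformers applies the same shift in its RoBERTa position-id helper). A minimal sketch of the remapping the hunk performs, assuming XLM-R's usual pad token id of 1:

    def xlmr_position_ids(pos: list[int], pad_id: int = 1) -> list[int]:
        # shift every real position by pad_id + 1, matching the fairseq
        # convention that learned positions start at padding_idx + 1
        out = [0] * len(pos)
        for i, p in enumerate(pos):
            if i == 0 or p != 0:  # p == 0 past the first token marks padding
                out[i] = p + pad_id + 1
        return out

    # four tokens at positions 0..3 map to embedding rows 2..5
    assert xlmr_position_ids([0, 1, 2, 3]) == [2, 3, 4, 5]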