XLMRoberta support
parent 081fe431aa
commit 60e36747ca
5 changed files with 142 additions and 17 deletions
@@ -585,6 +585,9 @@ class Model:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "a81863d07e75497e2194eb1a1574d5e5cd4d5f85a87a0728b922bf2bed6fb327":
+            # ref: https://huggingface.co/intfloat/multilingual-e5-base
+            res = "multilingual-e5-base"
         if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
            res = "chatglm-bpe"
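Note: the chkhsh values checked above are fingerprints of a model's pre-tokenizer; the converter encodes a fixed probe string and hashes the resulting token ids. A rough illustration of the idea (not part of this commit; the real script uses a much longer probe text, and the helper name here is made up):

from hashlib import sha256
from transformers import AutoTokenizer  # assumed to be installed

def tokenizer_fingerprint(repo_or_dir: str, probe_text: str) -> str:
    # encode a fixed probe string and hash the token ids; two tokenizers that
    # produce the same ids for the probe get the same fingerprint
    tok = AutoTokenizer.from_pretrained(repo_or_dir)
    ids = tok.encode(probe_text)
    return sha256(str(ids).encode()).hexdigest()

# e.g. tokenizer_fingerprint("intfloat/multilingual-e5-base", "Hello, world! 123")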
@@ -2447,6 +2450,68 @@ class BertModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("XLMRobertaModel")
+class XLMRobertaModel(Model):
+    model_arch = gguf.MODEL_ARCH.XLMROBERTA
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.vocab_size = None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_causal_attention(False)
+
+        # get pooling path
+        pooling_path = None
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
+
+        # get pooling type
+        if pooling_path is not None:
+            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+                pooling = json.load(f)
+            if pooling["pooling_mode_mean_tokens"]:
+                pooling_type = gguf.PoolingType.MEAN
+            elif pooling["pooling_mode_cls_token"]:
+                pooling_type = gguf.PoolingType.CLS
+            else:
+                raise NotImplementedError("Only MEAN and CLS pooling types supported")
+            self.gguf_writer.add_pooling_type(pooling_type)
+
+    def set_vocab(self):
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.vocab_size = len(tokens)
+
+        self.gguf_writer.add_token_type_count(int(self.hparams['type_vocab_size']))
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_eos_token(True)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # we are only using the model for embeddings so we don't need the pooling layer
+        if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+            return []  # we don't need these
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("NomicBertModel")
 class NomicBertModel(BertModel):
     model_arch = gguf.MODEL_ARCH.NOMIC_BERT
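Note: a converted file can be sanity-checked by reading it back with gguf-py. A rough sketch (not part of this commit; the output filename is hypothetical, and gguf-py is assumed to be installed from this repo):

from gguf import GGUFReader

reader = GGUFReader("multilingual-e5-base-f16.gguf")  # hypothetical output path
# metadata written by XLMRobertaModel.set_vocab / set_gguf_parameters above
print(sorted(k for k in reader.fields if k.startswith("tokenizer.")))
# tensor names should follow the xlm-roberta layout added to gguf-py constants
print([t.name for t in reader.tensors][:8])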
@@ -86,6 +86,7 @@ models = [
     {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
     {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
+    {"name": "multilingual-e5-base", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/intfloat/multilingual-e5-base", },
     {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
     {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
@@ -144,7 +145,7 @@ for model in models:
     name = model["name"]
     tokt = model["tokt"]
 
-    if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
+    if (tokt == TOKENIZER_TYPE.SPM and name != "multilingual-e5-base") or tokt == TOKENIZER_TYPE.UGM:
         continue
 
     # Skip if the tokenizer folder does not exist or there are other download issues previously
@@ -217,6 +217,7 @@ class MODEL_ARCH(IntEnum):
     BITNET = auto()
     T5 = auto()
     JAIS = auto()
+    XLMROBERTA = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -344,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.JAIS: "jais",
+    MODEL_ARCH.XLMROBERTA: "xlm-roberta",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -538,6 +540,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.XLMROBERTA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.NOMIC_BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
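Note: once these constants are in place, the tensor name templates registered for the new architecture can be listed directly from gguf-py. A small illustration (not part of this commit):

import gguf

arch = gguf.MODEL_ARCH.XLMROBERTA
print(gguf.MODEL_ARCH_NAMES[arch])  # "xlm-roberta"
for tensor in gguf.MODEL_TENSORS[arch]:
    # TENSOR_NAMES holds the GGUF name template, e.g. "blk.{bid}.attn_q"
    print(gguf.TENSOR_NAMES[tensor])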
@@ -95,6 +95,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
         LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
         LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
+        LLAMA_VOCAB_PRE_TYPE_E5 = 23,
     };
 
     // note: these values should be synchronized with ggml_rope
@@ -241,6 +241,7 @@ enum llm_arch {
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_JAIS,
+    LLM_ARCH_XLMROBERTA,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -285,6 +286,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_JAIS, "jais" },
+    { LLM_ARCH_XLMROBERTA, "xlm-roberta" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -765,6 +767,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_XLMROBERTA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES, "token_types" },
+            { LLM_TENSOR_POS_EMBD, "position_embd" },
+            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_NOMIC_BERT,
         {
@@ -4846,6 +4865,7 @@ static void llm_load_hparams(
                 hparams.f_max_alibi_bias = 8.0f;
             } break;
         case LLM_ARCH_BERT:
+        case LLM_ARCH_XLMROBERTA:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
@@ -5532,7 +5552,12 @@
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            if (tokenizer_pre == "multilingual-e5-base") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_E5;
+            } else {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            }
             vocab.tokenizer_add_space_prefix = true;
             vocab.tokenizer_clean_spaces = false;
             vocab.tokenizer_add_bos = true;
@@ -6398,12 +6423,12 @@ static bool llm_load_tensors(
             } break;
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_XLMROBERTA:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
                 model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
 
-                if (model.arch == LLM_ARCH_BERT) {
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train});
+                if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
+                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train});
                 }
 
                 model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
@@ -6415,7 +6440,7 @@ static bool llm_load_tensors(
 
                     auto & layer = model.layers[i];
 
-                    if (model.arch == LLM_ARCH_BERT) {
+                    if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
                         layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
@@ -6436,10 +6461,10 @@ static bool llm_load_tensors(
                     layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
                     layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
 
-                    if (model.arch == LLM_ARCH_BERT) {
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
-                        layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+                    if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+                        layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
                     } else {
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                     }
@@ -9769,7 +9794,7 @@ struct llm_build_context {
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
         inpL = ggml_add(ctx0, inpL, type_row0);
-        if (model.arch == LLM_ARCH_BERT) {
+        if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
             inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
         }
         cb(inpL, "inp_embd", -1);
@@ -9790,8 +9815,8 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur;
 
                 // self-attention
-                if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2 || model.arch == LLM_ARCH_XLMROBERTA) {
+                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
 
                     if (model.layers[il].attn_q_norm) {
@@ -9898,7 +9923,7 @@ struct llm_build_context {
                 cb(ffn_inp, "ffn_inp", il);
 
                 // feed-forward network
-                if (model.arch == LLM_ARCH_BERT) {
+                if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
                     cur = llm_build_ffn(ctx0, lctx, cur,
                             model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
                             NULL, NULL, NULL,
@@ -13856,6 +13881,7 @@ static struct ggml_cgraph * llama_build_graph(
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_XLMROBERTA:
             {
                 result = llm.build_bert();
             } break;
@@ -14065,8 +14091,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
     if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;
 
-        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        if (lctx.model.arch != LLM_ARCH_XLMROBERTA) {
+            ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        } else {
+            std::vector<llama_pos> position_ids(n_tokens, 0);
+            for (int64_t i = 0; i < n_tokens; ++i) {
+                if (i == 0 || batch.pos[i] != 0) {
+                    position_ids[i] = batch.pos[i] + lctx.model.vocab.special_pad_id + 1;
+                }
+            }
+            ggml_backend_tensor_set(lctx.inp_pos, position_ids.data(), 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        }
     }
 
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
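Note: the pad_id + 1 offset mirrors how XLM-RoBERTa assigns position ids upstream: padding keeps position padding_idx and real tokens are numbered from padding_idx + 1, so the learned position embeddings line up with the original checkpoint. A simplified reference sketch of that scheme (illustration only, not part of this commit and not the llama.cpp code path):

def xlmroberta_position_ids(token_ids: list[int], pad_id: int) -> list[int]:
    # non-padding tokens are numbered 1, 2, 3, ... and shifted by pad_id,
    # so the first real token sits at pad_id + 1; padding stays at pad_id
    out, pos = [], 0
    for tid in token_ids:
        if tid == pad_id:
            out.append(pad_id)
        else:
            pos += 1
            out.append(pad_id + pos)
    return out

# with pad_id = 1: [0, 12747, 8358, 2, 1] -> [2, 3, 4, 5, 1]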
@@ -15306,7 +15341,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
             }
             // Try to fall back to just the byte as a string
             const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
+            return vocab.token_to_id.count(buf2) ? vocab.token_to_id.at(buf2) : vocab.special_unk_id;
         }
         case LLAMA_VOCAB_TYPE_WPM:
         case LLAMA_VOCAB_TYPE_BPE: {
@@ -16471,6 +16506,11 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
                        llm_tokenizer_spm tokenizer(vocab);
+                       // Temporary workaround for the SPM preprocessor
+                       if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_E5) {
+                           std::regex ws_re("\\s+");
+                           raw_text = std::regex_replace(raw_text, ws_re, " ");
+                       }
                        llama_escape_whitespace(raw_text);
                        tokenizer.tokenize(raw_text, output);
                        is_prev_special = false;
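Note: the E5 workaround above only collapses runs of whitespace into a single space before SentencePiece tokenization. The same transformation in Python, for illustration (not part of this commit):

import re

def collapse_whitespace(text: str) -> str:
    # any run of spaces, tabs or newlines becomes a single space,
    # matching the std::regex_replace(raw_text, "\\s+", " ") above
    return re.sub(r"\s+", " ", text)

print(collapse_whitespace("query:  how\tmany   people\nlive in London?"))
# -> "query: how many people live in London?"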
@@ -19480,6 +19520,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_XLMROBERTA:
         case LLM_ARCH_CODESHELL:
             return LLAMA_ROPE_TYPE_NEOX;