XLMRoberta support

Oliver Ye 2024-07-22 17:26:11 -07:00
parent 081fe431aa
commit 60e36747ca
5 changed files with 142 additions and 17 deletions

convert_hf_to_gguf.py

@@ -585,6 +585,9 @@ class Model:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "a81863d07e75497e2194eb1a1574d5e5cd4d5f85a87a0728b922bf2bed6fb327":
+            # ref: https://huggingface.co/intfloat/multilingual-e5-base
+            res = "multilingual-e5-base"
         if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
             # ref: https://huggingface.co/THUDM/glm-4-9b-chat
             res = "chatglm-bpe"
@@ -2447,6 +2450,68 @@ class BertModel(Model):
         return [(self.map_tensor_name(name), data_torch)]


+@Model.register("XLMRobertaModel")
+class XLMRobertaModel(Model):
+    model_arch = gguf.MODEL_ARCH.XLMROBERTA
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.vocab_size = None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_causal_attention(False)
+
+        # get pooling path
+        pooling_path = None
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
+
+        # get pooling type
+        if pooling_path is not None:
+            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+                pooling = json.load(f)
+            if pooling["pooling_mode_mean_tokens"]:
+                pooling_type = gguf.PoolingType.MEAN
+            elif pooling["pooling_mode_cls_token"]:
+                pooling_type = gguf.PoolingType.CLS
+            else:
+                raise NotImplementedError("Only MEAN and CLS pooling types supported")
+            self.gguf_writer.add_pooling_type(pooling_type)
+
+    def set_vocab(self):
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.vocab_size = len(tokens)
+        self.gguf_writer.add_token_type_count(int(self.hparams["type_vocab_size"]))
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_eos_token(True)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # we are only using BERT for embeddings so we don't need the pooling layer
+        if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+            return []  # we don't need these
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("NomicBertModel")
 class NomicBertModel(BertModel):
     model_arch = gguf.MODEL_ARCH.NOMIC_BERT
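
Note: the pooling detection above follows the sentence-transformers convention of a modules.json listing a Pooling module whose config.json carries the mode flags. A minimal standalone sketch of the same lookup (the function name and the "none" fallback are illustrative, not part of this commit):

import json
from pathlib import Path

def detect_pooling(model_dir: str) -> str:
    # find the sentence-transformers Pooling module, if any
    root = Path(model_dir)
    modules = json.loads((root / "modules.json").read_text(encoding="utf-8"))
    pooling_path = next((m["path"] for m in modules
                         if m["type"] == "sentence_transformers.models.Pooling"), None)
    if pooling_path is None:
        return "none"
    # map the mode flags the same way set_gguf_parameters() does
    cfg = json.loads((root / pooling_path / "config.json").read_text(encoding="utf-8"))
    if cfg.get("pooling_mode_mean_tokens"):
        return "mean"
    if cfg.get("pooling_mode_cls_token"):
        return "cls"
    raise NotImplementedError("Only MEAN and CLS pooling types supported")

# detect_pooling("./multilingual-e5-base")  # the E5 checkpoints are expected to report "mean"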

convert_hf_to_gguf_update.py

@@ -86,6 +86,7 @@ models = [
     {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
     {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
+    {"name": "multilingual-e5-base", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/intfloat/multilingual-e5-base", },
     {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
     {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
@@ -144,7 +145,7 @@ for model in models:
     name = model["name"]
     tokt = model["tokt"]

-    if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
+    if (tokt == TOKENIZER_TYPE.SPM and name != "multilingual-e5-base") or tokt == TOKENIZER_TYPE.UGM:
         continue

     # Skip if the tokenizer folder does not exist or there are other download issues previously
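
Note: the chkhsh value registered above is the script's tokenizer checksum: the sha256 of the stringified token ids that the HF tokenizer produces for a fixed probe string. A sketch of the same computation (chktxt stands for the probe string defined in convert_hf_to_gguf_update.py; the hash only matches when that exact string is used):

from hashlib import sha256
from transformers import AutoTokenizer

def tokenizer_checksum(repo: str, chktxt: str) -> str:
    # same recipe the update script uses to derive chkhsh
    tokenizer = AutoTokenizer.from_pretrained(repo)
    return sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()

# tokenizer_checksum("intfloat/multilingual-e5-base", chktxt) should return
# "a81863d07e75497e2194eb1a1574d5e5cd4d5f85a87a0728b922bf2bed6fb327"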

gguf-py/gguf/constants.py

@@ -217,6 +217,7 @@ class MODEL_ARCH(IntEnum):
     BITNET     = auto()
     T5         = auto()
     JAIS       = auto()
+    XLMROBERTA = auto()


 class MODEL_TENSOR(IntEnum):
@@ -344,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BITNET:     "bitnet",
     MODEL_ARCH.T5:         "t5",
     MODEL_ARCH.JAIS:       "jais",
+    MODEL_ARCH.XLMROBERTA: "xlm-roberta",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -538,6 +540,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.XLMROBERTA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.NOMIC_BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
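
Note: with these entries in place, the tensor name patterns the converter will emit for the new arch can be listed straight from gguf-py; a quick sketch (run against a checkout that includes this change):

from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

# print the GGUF name for each tensor the arch declares (bid 0 as an example layer)
for t in MODEL_TENSORS[MODEL_ARCH.XLMROBERTA]:
    print(TENSOR_NAMES[t].format(bid=0))
# token_embd, token_embd_norm, token_types, position_embd, ..., blk.0.attn_q, ...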

include/llama.h

@@ -95,6 +95,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
         LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
         LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
+        LLAMA_VOCAB_PRE_TYPE_E5             = 23,
     };

     // note: these values should be synchronized with ggml_rope

src/llama.cpp

@@ -241,6 +241,7 @@ enum llm_arch {
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_JAIS,
+    LLM_ARCH_XLMROBERTA,
     LLM_ARCH_UNKNOWN,
 };
@@ -285,6 +286,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BITNET,          "bitnet"      },
     { LLM_ARCH_T5,              "t5"          },
     { LLM_ARCH_JAIS,            "jais"        },
+    { LLM_ARCH_XLMROBERTA,      "xlm-roberta" },
     { LLM_ARCH_UNKNOWN,         "(unknown)"   },
 };
@@ -765,6 +767,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_XLMROBERTA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_POS_EMBD,        "position_embd" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_NOMIC_BERT,
         {
@@ -4846,6 +4865,7 @@ static void llm_load_hparams(
                 hparams.f_max_alibi_bias = 8.0f;
             } break;
         case LLM_ARCH_BERT:
+        case LLM_ARCH_XLMROBERTA:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,        hparams.causal_attn);
@@ -5532,7 +5552,12 @@ static void llm_load_vocab(
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            if (tokenizer_pre == "multilingual-e5-base") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_E5;
+            } else {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            }
             vocab.tokenizer_add_space_prefix = true;
             vocab.tokenizer_clean_spaces = false;
             vocab.tokenizer_add_bos = true;
@@ -6398,12 +6423,12 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_BERT:
             case LLM_ARCH_NOMIC_BERT:
+            case LLM_ARCH_XLMROBERTA:
                 {
                     model.tok_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab});
                     model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});

-                    if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train});
+                    if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
+                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train});
                     }

                     model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
@@ -6415,7 +6440,7 @@ static bool llm_load_tensors(
                         auto & layer = model.layers[i];

-                        if (model.arch == LLM_ARCH_BERT) {
+                        if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
                             layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
                             layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
@@ -6436,10 +6461,10 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});

-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+                        if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
+                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
                         } else {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                         }
@@ -9769,7 +9794,7 @@ struct llm_build_context {
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
         inpL = ggml_add(ctx0, inpL, type_row0);
-        if (model.arch == LLM_ARCH_BERT) {
+        if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
             inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
         }
         cb(inpL, "inp_embd", -1);
@@ -9790,8 +9815,8 @@ struct llm_build_context {
             struct ggml_tensor * Vcur;

             // self-attention
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2 || model.arch == LLM_ARCH_XLMROBERTA) {
+                Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);

                 if (model.layers[il].attn_q_norm) {
@@ -9898,7 +9923,7 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);

             // feed-forward network
-            if (model.arch == LLM_ARCH_BERT) {
+            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_XLMROBERTA) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
@@ -13856,6 +13881,7 @@ static struct ggml_cgraph * llama_build_graph(
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_XLMROBERTA:
             {
                 result = llm.build_bert();
             } break;
@@ -14065,8 +14091,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;

-        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        if (lctx.model.arch != LLM_ARCH_XLMROBERTA) {
+            ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        } else {
+            // XLM-RoBERTa reserves the first pad_id + 1 rows of its learned
+            // position embeddings, so shift every real position by pad_id + 1
+            std::vector<llama_pos> position_ids(n_tokens, 0);
+            for (int64_t i = 0; i < n_tokens; ++i) {
+                if (i == 0 || batch.pos[i] != 0) {
+                    position_ids[i] = batch.pos[i] + lctx.model.vocab.special_pad_id + 1;
+                }
+            }
+            ggml_backend_tensor_set(lctx.inp_pos, position_ids.data(), 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        }
     }

     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
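
Note: the pad_id + 1 offset above mirrors how HF transformers builds position ids for (XLM-)RoBERTa, where positions count up from padding_idx + 1 and pad tokens sit at padding_idx. A reference sketch of that computation (equivalent to create_position_ids_from_input_ids in the HF implementation; the function name here is illustrative):

import torch

def xlmr_position_ids(input_ids: torch.Tensor, padding_idx: int = 1) -> torch.Tensor:
    # non-pad tokens get padding_idx + 1, padding_idx + 2, ...;
    # pad tokens are pinned to padding_idx itself
    mask = input_ids.ne(padding_idx).int()
    return (torch.cumsum(mask, dim=1) * mask).long() + padding_idx

# xlmr_position_ids(torch.tensor([[0, 5, 6, 2, 1, 1]]))
# -> tensor([[2, 3, 4, 5, 1, 1]])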
@@ -15306,7 +15341,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
            }

            // Try to fall back to just the byte as a string
            const char buf2[2] = { (char)ch, 0 };
-           return vocab.token_to_id.at(buf2);
+           const auto token = vocab.token_to_id.find(buf2);
+           return token != vocab.token_to_id.end() ? token->second : vocab.special_unk_id;
        }
        case LLAMA_VOCAB_TYPE_WPM:
        case LLAMA_VOCAB_TYPE_BPE: {
@@ -16471,6 +16506,11 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
                     llm_tokenizer_spm tokenizer(vocab);
+                    // temporary workaround for the SPM pre-tokenizer: collapse whitespace runs to a single space
+                    if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_E5) {
+                        std::regex ws_re("\\s+");
+                        raw_text = std::regex_replace(raw_text, ws_re, " ");
+                    }
                     llama_escape_whitespace(raw_text);
                     tokenizer.tokenize(raw_text, output);
                     is_prev_special = false;
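
Note: the workaround approximates the whitespace collapsing that the E5 tokenizer's normalizer performs before SentencePiece segmentation; a Python equivalent for cross-checking tokenizations against the HF tokenizer:

import re

def collapse_whitespace(text: str) -> str:
    # any run of whitespace (spaces, tabs, newlines) becomes one space,
    # mirroring the std::regex workaround above
    return re.sub(r"\s+", " ", text)

# collapse_whitespace("hello\t\n  world")  # -> "hello world"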
@@ -19480,6 +19520,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_XLMROBERTA:
         case LLM_ARCH_CODESHELL:
             return LLAMA_ROPE_TYPE_NEOX;