llama : add reranking support (#9510)
* py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classigication head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : aboud ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
1b2f992cd2
commit
f4d2b8846a
18 changed files with 602 additions and 56 deletions
|
@ -291,8 +291,13 @@ class Model:
|
|||
bid = int(part)
|
||||
break
|
||||
|
||||
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
|
||||
data: np.ndarray # type hint
|
||||
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
|
||||
data = data_torch.squeeze().numpy()
|
||||
|
||||
# if data ends up empty, it means data_torch was a scalar tensor -> restore
|
||||
if len(data.shape) == 0:
|
||||
data = data_torch.numpy()
|
||||
|
||||
n_dims = len(data.shape)
|
||||
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
|
@ -592,6 +597,9 @@ class Model:
|
|||
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
|
||||
# ref: https://huggingface.co/databricks/dbrx-base
|
||||
res = "dbrx"
|
||||
if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448":
|
||||
# ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||
res = "jina-v1-en"
|
||||
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
|
||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
|
||||
res = "jina-v2-en"
|
||||
|
@ -2601,7 +2609,7 @@ class NomicBertModel(BertModel):
|
|||
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
||||
|
||||
|
||||
@Model.register("XLMRobertaModel")
|
||||
@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
|
||||
class XLMRobertaModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
|
||||
|
@ -2699,6 +2707,11 @@ class XLMRobertaModel(BertModel):
|
|||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# if name starts with "roberta.", remove the prefix
|
||||
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
|
||||
if name.startswith("roberta."):
|
||||
name = name[8:]
|
||||
|
||||
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
|
||||
if name == "embeddings.position_embeddings.weight":
|
||||
if self._position_offset is not None:
|
||||
|
@ -3110,6 +3123,14 @@ class JinaBertV2Model(BertModel):
|
|||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# if name starts with "bert.", remove the prefix
|
||||
# e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||
if name.startswith("bert."):
|
||||
name = name[5:]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@Model.register("OpenELMForCausalLM")
|
||||
class OpenELMModel(Model):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue