llama : allow for user-specified embedding pooling type (#5849)
* allow for user-specified pooling type
* llama : use enum types over int
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
87c2e8b279
commit
475df1d6cf
5 changed files with 60 additions and 29 deletions
|
@ -1644,16 +1644,17 @@ class BertModel(Model):
|
|||
self.gguf_writer.add_causal_attention(False)
|
||||
|
||||
# get pooling path
|
||||
with open(self.dir_model / "modules.json", encoding="utf-8") as f:
|
||||
modules = json.load(f)
|
||||
pooling_path = None
|
||||
for mod in modules:
|
||||
if mod["type"] == "sentence_transformers.models.Pooling":
|
||||
pooling_path = mod["path"]
|
||||
break
|
||||
module_path = self.dir_model / "modules.json"
|
||||
if module_path.is_file():
|
||||
with open(module_path, encoding="utf-8") as f:
|
||||
modules = json.load(f)
|
||||
for mod in modules:
|
||||
if mod["type"] == "sentence_transformers.models.Pooling":
|
||||
pooling_path = mod["path"]
|
||||
break
|
||||
|
||||
# get pooling type
|
||||
pooling_type = gguf.PoolingType.NONE
|
||||
if pooling_path is not None:
|
||||
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
|
||||
pooling = json.load(f)
|
||||
|
@ -1663,8 +1664,7 @@ class BertModel(Model):
|
|||
pooling_type = gguf.PoolingType.CLS
|
||||
else:
|
||||
raise NotImplementedError("Only MEAN and CLS pooling types supported")
|
||||
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
def set_vocab(self):
|
||||
path = self.dir_model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue