convert script fixes

This commit is contained in:
Douglas Hanley 2024-02-15 10:17:24 -06:00
parent d2b77cce91
commit 34aa045de4
2 changed files with 4 additions and 5 deletions

View file

@ -1652,7 +1652,7 @@ class BertModel(Model):
self.gguf_writer.add_causal_attention(False) self.gguf_writer.add_causal_attention(False)
# get pooling path # get pooling path
with open(self.dir_model / "modules.json", "r", encoding="utf-8") as f: with open(self.dir_model / "modules.json", encoding="utf-8") as f:
modules = json.load(f) modules = json.load(f)
pooling_path = None pooling_path = None
for mod in modules: for mod in modules:
@ -1663,15 +1663,14 @@ class BertModel(Model):
# get pooling type # get pooling type
pooling_type = gguf.PoolingType.NONE pooling_type = gguf.PoolingType.NONE
if pooling_path is not None: if pooling_path is not None:
with open(self.dir_model / pooling_path / "config.json", "r", encoding="utf-8") as f: with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
pooling = json.load(f) pooling = json.load(f)
if pooling["pooling_mode_mean_tokens"]: if pooling["pooling_mode_mean_tokens"]:
pooling_type = gguf.PoolingType.MEAN pooling_type = gguf.PoolingType.MEAN
elif pooling["pooling_mode_cls_token"]: elif pooling["pooling_mode_cls_token"]:
pooling_type = gguf.PoolingType.CLS pooling_type = gguf.PoolingType.CLS
else: else:
print("Only MEAN and CLS pooling types supported") raise NotImplementedError("Only MEAN and CLS pooling types supported")
sys.exit(1)
self.gguf_writer.add_pooling_type(pooling_type.value) self.gguf_writer.add_pooling_type(pooling_type.value)

View file

@ -559,7 +559,7 @@ class RopeScalingType(Enum):
YARN = 'yarn' YARN = 'yarn'
class PoolingType(Enum): class PoolingType(IntEnum):
NONE = 0 NONE = 0
MEAN = 1 MEAN = 1
CLS = 2 CLS = 2