convert script fixes

This commit is contained in:
Douglas Hanley 2024-02-15 10:17:24 -06:00
parent d2b77cce91
commit 34aa045de4
2 changed files with 4 additions and 5 deletions

View file

@ -1652,7 +1652,7 @@ class BertModel(Model):
self.gguf_writer.add_causal_attention(False)
# get pooling path
with open(self.dir_model / "modules.json", "r", encoding="utf-8") as f:
with open(self.dir_model / "modules.json", encoding="utf-8") as f:
modules = json.load(f)
pooling_path = None
for mod in modules:
@ -1663,15 +1663,14 @@ class BertModel(Model):
# get pooling type
pooling_type = gguf.PoolingType.NONE
if pooling_path is not None:
with open(self.dir_model / pooling_path / "config.json", "r", encoding="utf-8") as f:
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
pooling = json.load(f)
if pooling["pooling_mode_mean_tokens"]:
pooling_type = gguf.PoolingType.MEAN
elif pooling["pooling_mode_cls_token"]:
pooling_type = gguf.PoolingType.CLS
else:
print("Only MEAN and CLS pooling types supported")
sys.exit(1)
raise NotImplementedError("Only MEAN and CLS pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type.value)

View file

@ -559,7 +559,7 @@ class RopeScalingType(Enum):
YARN = 'yarn'
class PoolingType(Enum):
class PoolingType(IntEnum):
NONE = 0
MEAN = 1
CLS = 2