convert script fixes
This commit is contained in:
parent
d2b77cce91
commit
34aa045de4
2 changed files with 4 additions and 5 deletions
|
@ -1652,7 +1652,7 @@ class BertModel(Model):
|
|||
self.gguf_writer.add_causal_attention(False)
|
||||
|
||||
# get pooling path
|
||||
with open(self.dir_model / "modules.json", "r", encoding="utf-8") as f:
|
||||
with open(self.dir_model / "modules.json", encoding="utf-8") as f:
|
||||
modules = json.load(f)
|
||||
pooling_path = None
|
||||
for mod in modules:
|
||||
|
@ -1663,15 +1663,14 @@ class BertModel(Model):
|
|||
# get pooling type
|
||||
pooling_type = gguf.PoolingType.NONE
|
||||
if pooling_path is not None:
|
||||
with open(self.dir_model / pooling_path / "config.json", "r", encoding="utf-8") as f:
|
||||
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
|
||||
pooling = json.load(f)
|
||||
if pooling["pooling_mode_mean_tokens"]:
|
||||
pooling_type = gguf.PoolingType.MEAN
|
||||
elif pooling["pooling_mode_cls_token"]:
|
||||
pooling_type = gguf.PoolingType.CLS
|
||||
else:
|
||||
print("Only MEAN and CLS pooling types supported")
|
||||
sys.exit(1)
|
||||
raise NotImplementedError("Only MEAN and CLS pooling types supported")
|
||||
|
||||
self.gguf_writer.add_pooling_type(pooling_type.value)
|
||||
|
||||
|
|
|
@ -559,7 +559,7 @@ class RopeScalingType(Enum):
|
|||
YARN = 'yarn'
|
||||
|
||||
|
||||
class PoolingType(Enum):
|
||||
class PoolingType(IntEnum):
|
||||
NONE = 0
|
||||
MEAN = 1
|
||||
CLS = 2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue