flake8 support

This commit is contained in:
lixiaopu 2024-01-25 08:52:44 +08:00
parent d64bb815e8
commit 154319c42d

View file

@@ -575,6 +575,7 @@ class MPTModel(Model):
if new_name == "token_embd.weight": if new_name == "token_embd.weight":
self.gguf_writer.add_tensor("output.weight", data) self.gguf_writer.add_tensor("output.weight", data)
class OrionModel(Model): class OrionModel(Model):
def set_vocab(self): def set_vocab(self):
self._set_vocab_sentencepiece() self._set_vocab_sentencepiece()
@@ -612,9 +613,7 @@ class OrionModel(Model):
# Collect tensors from generator object # Collect tensors from generator object
model_kv = dict(self.get_tensors()) model_kv = dict(self.get_tensors())
block_count = self.hparams["num_hidden_layers"] block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
for name, data_torch in model_kv.items(): for name, data_torch in model_kv.items():
# we don't need these # we don't need these
@@ -653,6 +652,7 @@ class OrionModel(Model):
print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data) self.gguf_writer.add_tensor(new_name, data)
class BaichuanModel(Model): class BaichuanModel(Model):
def set_vocab(self): def set_vocab(self):
self._set_vocab_sentencepiece() self._set_vocab_sentencepiece()