Consolidate Phi model conversion handling in convert-hf-to-gguf.py

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
This commit is contained in:
teleprint-me 2023-12-20 15:26:25 -05:00
parent e96f40bf99
commit ea6ae8d04c
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@@ -10,7 +10,7 @@ import re
 import sys
 from enum import IntEnum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
+from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Optional, cast

 import numpy as np
 import torch
@@ -22,7 +22,6 @@ if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf

 ###### MODEL DEFINITIONS ######

 class SentencePieceTokenTypes(IntEnum):
(Note: this hunk removes one line not visible in the rendered diff — line count 7 → 6.)
@@ -183,7 +182,7 @@ class Model:
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
         if model_architecture == "PhiForCausalLM":
-            return Phi2Model
+            return PhiModel
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -224,7 +223,7 @@ class Model:
         if arch == "MixtralForCausalLM":
             return gguf.MODEL_ARCH.LLAMA
         if arch == "PhiForCausalLM":
-            return gguf.MODEL_ARCH.PHI2
+            return gguf.MODEL_ARCH.PHI

         raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -985,11 +984,11 @@ class QwenModel(Model):
         self.gguf_writer.add_tensor(new_name, data)


-class Phi2Model(Model):
+class PhiModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layer"]

-        self.gguf_writer.add_name("Phi2")
+        self.gguf_writer.add_name("Phi")
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
         self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])