Consolidate Phi model conversion handling in convert-hf-to-gguf.py

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
This commit is contained in:
teleprint-me 2023-12-20 15:26:25 -05:00
parent e96f40bf99
commit ea6ae8d04c
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -10,7 +10,7 @@ import re
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Optional, cast
import numpy as np
import torch
@ -22,7 +22,6 @@ if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum):
@ -183,7 +182,7 @@ class Model:
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "PhiForCausalLM":
return Phi2Model
return PhiModel
return Model
def _is_model_safetensors(self) -> bool:
@ -224,7 +223,7 @@ class Model:
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2
return gguf.MODEL_ARCH.PHI
raise NotImplementedError(f'Architecture "{arch}" not supported!')
@ -985,11 +984,11 @@ class QwenModel(Model):
self.gguf_writer.add_tensor(new_name, data)
class Phi2Model(Model):
class PhiModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_name("Phi")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])