llama : add phi3 128K model support (#7225)
* add phi3 128k support in convert-hf-to-gguf * add phi3 128k support in cuda * address build warnings on llama.cpp * adjust index value in cuda long rope freq factors * add long rope support in ggml cpu backend * make freq factors only depend on ctx size * remove unused rope scaling type 'su' frin gguf converter * fix flint warnings on convert-hf-to-gguf.py * set to the short freq factor when context size is small than trained context size * add one line of comments * metal : support rope freq_factors * ggml : update ggml_rope_ext API to support freq. factors * backends : add dev messages to support rope freq. factors * minor : style * tests : update to use new rope API * backends : fix pragma semicolons * minor : cleanup * llama : move rope factors from KV header to tensors * llama : remove tmp assert * cuda : fix compile warning * convert : read/write n_head_kv * llama : fix uninitialized tensors --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6369bf0433
commit
201cc11afa
15 changed files with 484 additions and 233 deletions
|
@ -14,6 +14,7 @@ from pathlib import Path
|
|||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
@ -1784,23 +1785,59 @@ class Phi3MiniModel(Model):
|
|||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
rot_pct = 1.0
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
rms_eps = self.find_hparam(["rms_norm_eps"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rope_dims = n_embd // n_head
|
||||
|
||||
self.gguf_writer.add_name("Phi3")
|
||||
self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
|
||||
|
||||
self.gguf_writer.add_context_length(max_pos_embds)
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(8192)
|
||||
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
|
||||
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dims)
|
||||
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
# write rope scaling for long context (128k) model
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if (rope_scaling is None):
|
||||
return
|
||||
|
||||
scale = max_pos_embds / orig_max_pos_embds
|
||||
|
||||
rope_scaling_type = rope_scaling.get('type', '').lower()
|
||||
if len(rope_scaling_type) == 0:
|
||||
raise KeyError('Missing the required key rope_scaling.type')
|
||||
|
||||
if rope_scaling_type == 'su':
|
||||
attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0
|
||||
elif rope_scaling_type == 'yarn':
|
||||
attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0
|
||||
else:
|
||||
raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet')
|
||||
|
||||
self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||
|
||||
|
||||
@Model.register("PlamoForCausalLM")
|
||||
class PlamoModel(Model):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue