llama : move rope factors from KV header to tensors
commit 600896b882
parent d93b5cad0a

4 changed files with 46 additions and 73 deletions
convert-hf-to-gguf.py

@@ -1834,8 +1834,8 @@ class Phi3MiniModel(Model):
         if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
             raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
 
-        self.gguf_writer.add_rope_scaling_freq_long_factors(long_factors)
-        self.gguf_writer.add_rope_scaling_freq_short_factors(short_factors)
+        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
+        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
 
 
 @Model.register("PlamoForCausalLM")
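Note (not part of the diff): after conversion, the factors live in the GGUF file as ordinary tensors named rope_factors_long.weight and rope_factors_short.weight rather than as KV metadata. A minimal sketch of how one could confirm that with the gguf Python package; the model path is a placeholder.

# Sketch: list the rope factor tensors in a converted file (path is a placeholder).
from gguf import GGUFReader

reader = GGUFReader("phi-3-mini-128k.gguf")
for tensor in reader.tensors:
    if tensor.name in ("rope_factors_long.weight", "rope_factors_short.weight"):
        # tensor.data is a numpy view of the stored float32 values
        print(tensor.name, tensor.data.shape)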
gguf-py/gguf/constants.py

@@ -61,8 +61,6 @@ class Keys:
         FREQ_BASE = "{arch}.rope.freq_base"
         SCALING_TYPE = "{arch}.rope.scaling.type"
         SCALING_FACTOR = "{arch}.rope.scaling.factor"
-        SCALING_LONG_FACTORS = "{arch}.rope.scaling.freq_long_factors"
-        SCALING_SHORT_FACTORS = "{arch}.rope.scaling.freq_short_factors"
         SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
@@ -151,6 +149,8 @@ class MODEL_TENSOR(IntEnum):
     OUTPUT = auto()
     OUTPUT_NORM = auto()
     ROPE_FREQS = auto()
+    ROPE_FACTORS_LONG = auto()
+    ROPE_FACTORS_SHORT = auto()
     ATTN_Q = auto()
     ATTN_K = auto()
     ATTN_V = auto()
@@ -228,6 +228,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.OUTPUT_NORM: "output_norm",
     MODEL_TENSOR.OUTPUT: "output",
     MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
+    MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
+    MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
     MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
     MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
     MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
gguf-py/gguf/gguf_writer.py

@@ -433,12 +433,6 @@ class GGUFWriter:
     def add_rope_scaling_factor(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
 
-    def add_rope_scaling_freq_long_factors(self, value: Sequence[float]) -> None:
-        self.add_array(Keys.Rope.SCALING_LONG_FACTORS.format(arch=self.arch), value)
-
-    def add_rope_scaling_freq_short_factors(self, value: Sequence[float]) -> None:
-        self.add_array(Keys.Rope.SCALING_SHORT_FACTORS.format(arch=self.arch), value)
-
     def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None:
         self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
 
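Note (not part of the diff): with the add_rope_scaling_freq_*_factors writers gone, a converter attaches the factors as ordinary tensors through GGUFWriter.add_tensor, as the Phi-3 change above does. A rough, self-contained sketch of that pattern; the output path, architecture string, and factor values are placeholders, not taken from the commit.

# Sketch: emit rope factors as tensors instead of KV metadata (values are placeholders).
import numpy as np
import gguf

writer = gguf.GGUFWriter("example.gguf", "phi3")
long_factors  = np.ones(48, dtype=np.float32)   # rope_dims / 2 entries, illustrative
short_factors = np.ones(48, dtype=np.float32)

writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", long_factors)
writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", short_factors)

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()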
llama.cpp (79 changed lines)
@@ -304,8 +304,6 @@ enum llm_kv {
     LLM_KV_ROPE_SCALE_LINEAR,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
-    LLM_KV_ROPE_SCALING_LONG_FACTORS,
-    LLM_KV_ROPE_SCALING_SHORT_FACTORS,
     LLM_KV_ROPE_SCALING_ATTN_FACTOR,
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
@@ -384,8 +382,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
     { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
     { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
-    { LLM_KV_ROPE_SCALING_LONG_FACTORS, "%s.rope.scaling.freq_long_factors" },
-    { LLM_KV_ROPE_SCALING_SHORT_FACTORS, "%s.rope.scaling.freq_short_factors" },
     { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
@@ -442,6 +438,8 @@ enum llm_tensor {
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_ROPE_FREQS,
+    LLM_TENSOR_ROPE_FACTORS_LONG,
+    LLM_TENSOR_ROPE_FACTORS_SHORT,
     LLM_TENSOR_ATTN_Q,
     LLM_TENSOR_ATTN_K,
     LLM_TENSOR_ATTN_V,
@@ -812,6 +810,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
     { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
     { LLM_TENSOR_OUTPUT, "output" },
+    { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
+    { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
     { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
     { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
     { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
@@ -1756,14 +1756,11 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
 
-    std::vector<float> rope_long_factors;
-    std::vector<float> rope_short_factors;
-    float rope_attn_factor = 1.0f;
-
     // for State Space Models
     uint32_t ssm_d_conv = 0;
     uint32_t ssm_d_inner = 0;
@@ -1799,10 +1796,6 @@ struct llama_hparams {
         if (this->rope_finetuned != other.rope_finetuned) return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
-        if (this->rope_long_factors != other.rope_long_factors) return true;
-        if (this->rope_short_factors != other.rope_short_factors) return true;
-        if (this->rope_attn_factor != other.rope_attn_factor) return true;
-
         if (this->ssm_d_conv != other.ssm_d_conv) return true;
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
         if (this->ssm_d_state != other.ssm_d_state) return true;
@@ -1812,6 +1805,7 @@ struct llama_hparams {
 
         if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
         if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
+        if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true;
         if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
         if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
 
@@ -2117,6 +2111,10 @@ struct llama_model {
     struct ggml_tensor * output;
     struct ggml_tensor * output_b;
 
+    // long rope factors
+    struct ggml_tensor * rope_long;
+    struct ggml_tensor * rope_short;
+
     std::vector<llama_layer> layers;
 
     llama_split_mode split_mode;
@@ -2260,8 +2258,6 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
 
-    struct ggml_tensor * freq_factors = nullptr; // F32 [kv_size / 2]
-
     // control vectors
     struct llama_control_vector cvec;
 };
@@ -3898,12 +3894,6 @@ static void llm_load_hparams(
         }
         hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
-        ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false);
-        ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false);
-
-        GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2);
-        GGML_ASSERT(hparams.rope_long_factors.size() == hparams.rope_short_factors.size());
-
         ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
 
         // sanity check for n_rot (optional)
@@ -4937,6 +4927,7 @@ static bool llm_load_tensors(
     // create tensors for the weights
     {
         const int64_t n_embd = hparams.n_embd;
+        const int64_t n_embd_head = n_embd / hparams.n_head;
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
         const int64_t n_embd_gqa = n_embd_v_gqa;
@@ -5648,6 +5639,9 @@ static bool llm_load_tensors(
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
 
+                model.rope_long = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, false);
+                model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false);
+
                 // output
                 {
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
@@ -6878,7 +6872,7 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
-        lctx.freq_factors = build_freq_factors();
+        struct ggml_tensor * rope_factors = build_rope_factors();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
@@ -6889,7 +6883,7 @@ struct llm_build_context {
                             ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                             ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                             0),
-                        lctx.inp_K_shift, lctx.freq_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
 
             cb(tmp, "K_shifted", il);
@@ -6994,17 +6988,15 @@ struct llm_build_context {
         return lctx.inp_pos;
     }
 
-    struct ggml_tensor * build_freq_factors() {
-        if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
-            lctx.freq_factors = nullptr;
-            return nullptr;
+    struct ggml_tensor * build_rope_factors() {
+        // choose long/short freq factors based on the context size
+        const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
+
+        if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
+            return model.rope_long;
         }
 
-        lctx.freq_factors = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_head_k / 2);
-        cb(lctx.freq_factors, "freq_factors", -1);
-        ggml_set_input(lctx.freq_factors);
-
-        return lctx.freq_factors;
+        return model.rope_short;
     }
 
     struct ggml_tensor * build_inp_out_ids() {
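Note (not part of the diff): the choice build_rope_factors makes above reduces to comparing the per-sequence context size against the model's original training context (n_yarn_orig_ctx). A Python rendering of that rule, with illustrative numbers only:

# Sketch of the long/short selection rule; the concrete values are illustrative only.
def select_rope_factors(n_ctx, n_seq_max, n_yarn_orig_ctx, rope_long, rope_short):
    n_ctx_per_seq = n_ctx // n_seq_max          # context available to each sequence
    # long factors only when a sequence can exceed the original training context
    return rope_long if n_ctx_per_seq > n_yarn_orig_ctx else rope_short

factors = select_rope_factors(131072, 1, 4096, [1.08] * 48, [1.0] * 48)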
@@ -9126,7 +9118,9 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         // rope freq factors for 128k context
-        struct ggml_tensor* freq_factors = build_freq_factors();
+        struct ggml_tensor * rope_factors = build_rope_factors();
+
+        GGML_ASSERT(rope_factors != nullptr && "rope_factors is required for phi3"); // TMP: remove me
 
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -9165,7 +9159,7 @@ struct llm_build_context {
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
             Qcur = ggml_rope_ext(
-                ctx0, Qcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx,
+                ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
                 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
             );
             cb(Qcur, "Qcur", il);
@@ -9174,7 +9168,7 @@ struct llm_build_context {
             cb(Qcur, "Qcur", il);
 
             Kcur = ggml_rope_ext(
-                ctx0, Kcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx,
+                ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
                 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
             );
             cb(Kcur, "Kcur", il);
@@ -10966,23 +10960,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (lctx.freq_factors) {
-        // TODO: this might have to be hparams.n_rot instead of hparams.n_embd_head_k, but maybe it does not matter
-        const auto freq_dim = hparams.n_embd_head_k / 2;
-
-        GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
-        GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
-        GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
-
-        // choose long/short freq factors based on the context size
-        const auto n_ctx = llama_n_ctx(&lctx);
-        if (n_ctx > hparams.n_yarn_orig_ctx) {
-            ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
-        } else {
-            ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
-        }
-    }
-
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 