Add llama 3.1 rope scaling factors to llama conversion and inference

This commit generates the rope scaling factors at conversion time and adds them to the resulting model as a tensor. At inference time, these factors are passed to the `ggml_rope_ext` rope operation, improving results for context windows above 8192 tokens.
jmorganca 2024-07-24 15:46:47 -04:00
parent 01245f5b16
commit e6bacb405a
2 changed files with 41 additions and 2 deletions
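For context, the "llama3" rope scaling rescales the per-dimension rotary frequencies: dimensions whose wavelength is shorter than the original context window are left untouched, dimensions whose wavelength exceeds it are stretched by the full scaling factor, and the band in between is interpolated smoothly. The snippet below is a minimal standalone sketch of that computation; the function name and default arguments are illustrative (the defaults follow the `rope_scaling` values typically found in Llama 3.1 configs), not part of the commit itself.

```python
import math


def llama3_rope_factors(dim: int = 128, base: float = 500000.0, factor: float = 8.0,
                        low_freq_factor: float = 1.0, high_freq_factor: float = 4.0,
                        old_context_len: int = 8192) -> list[float]:
    """Per-dimension RoPE scaling factors for Llama 3.1-style 'llama3' rope scaling."""
    # rotary frequencies: base^(-i/dim) for i = 0, 2, 4, ..., dim-2
    freqs = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    factors = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # short wavelength (high frequency): leave unscaled
            factors.append(1.0)
        elif wavelen > low_freq_wavelen:
            # wavelength longer than the original context window: apply the full factor
            factors.append(factor)
        else:
            # transition band: interpolate smoothly between the two regimes
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
    return factors


# e.g. with Llama 3.1 8B-like values: head_dim = 128, rope_theta = 500000
print(llama3_rope_factors()[:4], "...", llama3_rope_factors()[-4:])
```

The conversion change below computes the same values with torch and stores them as a `rope_freqs` tensor in the GGUF file; the llama.cpp change then passes that tensor to `ggml_rope_ext` instead of `nullptr`.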


@@ -1514,6 +1514,35 @@ class LlamaModel(Model):
         if self.hparams.get("vocab_size", 32000) == 49152:
             self.gguf_writer.add_add_bos_token(False)
 
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = hparams.get("rope_theta", 10000.0)
+                dim = int((hparams["hidden_size"] // hparams["num_attention_heads"]) * hparams.get("partial_rotary_embeddings", 1.0))
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        assert low_freq_wavelen != high_freq_wavelen
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                self.gguf_writer.add_rope_scaling_attn_factors(1.0)
+                self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FREQS] + ".weight", np.array(rope_factors, dtype=np.float32))
+
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
         if n_head_kv is not None and n_head != n_head_kv:


@@ -2455,6 +2455,7 @@ struct llama_layer {
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
    struct ggml_tensor * rope_short = nullptr;
+    struct ggml_tensor * rope_freqs = nullptr;
 
     // bitnet scale
     struct ggml_tensor * wq_scale;
@@ -6054,6 +6055,8 @@ static bool llm_load_tensors(
             layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
+            layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), { n_embd/n_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+
             if (n_expert == 0) {
                 layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                 layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
@@ -8531,6 +8534,10 @@ struct llm_build_context {
         // choose long/short freq factors based on the context size
         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
 
+        if (model.layers[il].rope_freqs != nullptr) {
+            return model.layers[il].rope_freqs;
+        }
+
         if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
             return model.layers[il].rope_long;
         }
@@ -8725,6 +8732,9 @@ struct llm_build_context {
         // self-attention
         {
+            // rope freq factors for llama3; may return nullptr for llama2 and other models
+            struct ggml_tensor * rope_factors = build_rope_factors(il);
+
             // compute Q and K and RoPE them
             struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
             cb(Qcur, "Qcur", il);
@@ -8748,14 +8758,14 @@ struct llm_build_context {
             }
 
             Qcur = ggml_rope_ext(
-                ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor, beta_fast, beta_slow
             );
             cb(Qcur, "Qcur", il);
 
             Kcur = ggml_rope_ext(
-                ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor, beta_fast, beta_slow
             );