un-hardcode max-alibi-bias

fmz 2024-07-01 09:26:56 -07:00
parent f42285f0e5
commit 2d4de517bb
2 changed files with 20 additions and 5 deletions

convert-hf-to-gguf.py

@@ -2938,6 +2938,7 @@ class T5Model(Model):
        return [(self.map_tensor_name(name), data_torch)]


@Model.register("JAISLMHeadModel")
class JaisModel(Model):
    model_arch = gguf.MODEL_ARCH.JAIS
@@ -2954,7 +2955,7 @@ class JaisModel(Model):
        self.embeddings_scale = 1.0
        # note: For some JAIS flavors, output is tied to (same as) wte in original model
        self.output_is_wte = False
        if 'mup_embeddings_scale' in self.hparams:
            self.output_is_wte = True # Hack (?)
            self.embeddings_scale = self.hparams['mup_embeddings_scale']
        elif 'embeddings_scale' in self.hparams:
@@ -2963,7 +2964,7 @@ class JaisModel(Model):
            assert False

        self.width_scale = 1.0
        if 'mup_output_alpha' in self.hparams:
            assert 'mup_width_scale' in self.hparams
            self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
        elif 'width_scale' in self.hparams:
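These muP factors are not written to the GGUF as metadata; the converter folds them into the exported tensors in modify_tensors below (wte is scaled by embeddings_scale, the output head by width_scale). A rough sketch of that folding, with made-up shapes and scale values standing in for the hparams read above:

import torch

embeddings_scale = 14.6          # stand-in for hparams['mup_embeddings_scale']
width_scale = 0.1 * 256 / 5120   # stand-in for mup_output_alpha * mup_width_scale

wte = torch.randn(84992, 5120)   # token embedding table
lm_head = wte                    # output_is_wte: output head tied to wte

# scaled once at export time, so inference needs no extra multiplications
wte_exported = wte * embeddings_scale
lm_head_exported = lm_head * width_scale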
@@ -2984,13 +2985,27 @@ class JaisModel(Model):
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

+        # Hack to populate self.tensor_names
+        all(self.get_tensors())
+        if 'transformer.relative_pe.slopes' not in self.tensor_names:
+            self.gguf_writer.add_max_alibi_bias(8.0)
+        # else set later
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid # unused

        tensors: list[tuple[str, Tensor]] = []

        # we don't need these
if name.endswith((".attn.bias", "relative_pe.slopes")):
if name.endswith((".attn.bias")):
return tensors
if name.endswith(("relative_pe.slopes")):
# calculate ALiBi bias
n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
first_val = float(data_torch._data[0])
alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)
self.gguf_writer.add_max_alibi_bias(alibi_bias)
return tensors
if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):

src/llama.cpp

@@ -4902,8 +4902,8 @@ static void llm_load_hparams(
        case LLM_ARCH_JAIS:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

-                // TODO: become GGUF KV parameter
-                hparams.f_max_alibi_bias = 8.0f;
+                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);

                switch (hparams.n_layer) {
                    case 24: model.type = e_model::MODEL_1_3B; break;
                    case 40: model.type = e_model::MODEL_13B; break;
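On the C++ side the hard-coded 8.0f is replaced by reading the new KV, so each converted model carries its own value. One way to confirm a converted file actually contains the key is the gguf-py package that ships in the repo; a quick sketch (the model file name is hypothetical):

from gguf import GGUFReader

reader = GGUFReader("jais-13b.gguf")  # hypothetical local model file
field = reader.get_field("jais.attention.max_alibi_bias")
if field is not None:
    # scalar KVs keep their value in the last data part
    print(float(field.parts[field.data[-1]][0]))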