convert : reduce unnecessary variables in Params

Cebtenzzre 2023-09-06 13:00:04 -04:00
parent dcb058ce5d
commit 281b26e647

@@ -210,23 +210,12 @@ class Params:
     def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))
 
-        n_vocab = config["vocab_size"]
-        n_embd = config["hidden_size"]
-        n_layer = config["num_hidden_layers"]
-        n_ff = config["intermediate_size"]
-        n_head = config["num_attention_heads"]
-        n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
-        f_norm_eps = config["rms_norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
         rope_scaling = config.get("rope_scaling")
         if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
             f_rope_scale = config["rope_scaling"].get("factor")
         else:
             f_rope_scale = None
 
-        n_mult = Params.find_n_mult(n_ff, n_embd)
-
         if "max_sequence_length" in config:
             n_ctx = config["max_sequence_length"]
         elif "max_position_embeddings" in config:
@@ -236,16 +225,16 @@ class Params:
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd = n_embd,
-            n_mult = n_mult,
-            n_layer = n_layer,
+            n_vocab = config["vocab_size"],
+            n_embd = config["hidden_size"],
+            n_mult = Params.find_n_mult(n_ff, n_embd),
+            n_layer = config["num_hidden_layers"],
             n_ctx = n_ctx,
-            n_ff = n_ff,
-            n_head = n_head,
-            n_head_kv = n_head_kv,
-            f_norm_eps = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
+            n_ff = config["intermediate_size"],
+            n_head = config["num_attention_heads"],
+            n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head,
+            f_norm_eps = config["rms_norm_eps"],
+            f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None,
             f_rope_scale = f_rope_scale,
         )
 
@@ -255,16 +244,6 @@ class Params:
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))
 
-        n_vocab = config["vocab_size"] if "vocab_size" in config else -1
-        n_embd = config["dim"]
-        n_layer = config["n_layers"]
-        n_mult = config["multiple_of"]
-        n_ff = -1
-        n_head = config["n_heads"]
-        n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
-        f_norm_eps = config["norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
         if f_rope_freq_base and f_rope_freq_base == 1000000:
             # CodeLlama
@@ -276,23 +255,17 @@ class Params:
             # LLaMA v1
             n_ctx = 2048
 
-        if n_vocab == -1:
-            n_vocab = model["tok_embeddings.weight"].shape[0]
-
-        if n_ff == -1:
-            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
-
         return Params(
-            n_vocab = n_vocab,
-            n_embd = n_embd,
-            n_mult = n_mult,
-            n_layer = n_layer,
+            n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
+            n_embd = config["dim"],
+            n_mult = config["multiple_of"],
+            n_layer = config["n_layers"],
             n_ctx = n_ctx,
-            n_ff = n_ff,
-            n_head = n_head,
-            n_head_kv = n_head_kv,
-            f_norm_eps = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_head = config["n_heads"],
+            n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head,
+            f_norm_eps = config["norm_eps"],
+            f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None,
         )
 
     @staticmethod
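
For readers skimming the diff, the refactor above amounts to dropping locals that are read only once and looking each value up in the parsed config dict directly at the point of use inside the Params(...) call. Below is a minimal standalone sketch of that pattern using the same HF config key names; the HParams dataclass and load_hparams helper are illustrative stand-ins and not code from this repository.

import json
from dataclasses import dataclass
from pathlib import Path


@dataclass
class HParams:
    n_vocab: int
    n_embd: int
    n_layer: int
    n_head: int
    n_head_kv: int


def load_hparams(config_path: Path) -> HParams:
    config = json.loads(config_path.read_text())

    # Before the refactor each field is first bound to a local
    # (n_vocab = config["vocab_size"], ...) and then forwarded.
    # After it, the lookup happens inside the constructor call; only a
    # value that is reused (n_head, as the fallback for n_head_kv)
    # keeps a local binding.
    n_head = config["num_attention_heads"]
    return HParams(
        n_vocab   = config["vocab_size"],
        n_embd    = config["hidden_size"],
        n_layer   = config["num_hidden_layers"],
        n_head    = n_head,
        n_head_kv = config.get("num_key_value_heads", n_head),
    )


if __name__ == "__main__":
    # Throwaway usage example: write a small config.json and load it.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        cfg = Path(tmp) / "config.json"
        cfg.write_text(json.dumps({
            "vocab_size": 32000,
            "hidden_size": 4096,
            "num_hidden_layers": 32,
            "num_attention_heads": 32,
        }))
        print(load_hparams(cfg))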