Address review comments by foldl.

This commit is contained in:
Heiner 2024-05-09 22:43:38 +02:00
parent d894497a96
commit ef671c693d

View file

@ -204,7 +204,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
# First operate on meta tensors to find shapes and dtypes for GGUF header.
for idx, name in enumerate(weight_names):
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
f.add_tensor_info(
@ -226,8 +226,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
for name in weight_names:
weight, scales = weights.pop(name)
tensor = convert_weight(name, weight, scales, config.experts)
tensor = maybe_permute_tensor(name, tensor, config)
tensor = convert_weight(name, weight, scales, config)
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
@ -258,7 +257,7 @@ def from_numpy(array):
return torch.from_numpy(array)
def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
def convert_weight(name, weight, scales, config, dtype=torch.float32, device=None):
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
weight = from_numpy(weight).to(device=device, dtype=dtype)
if scales is not None:
@ -271,32 +270,19 @@ def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=No
else:
weight = weight * scale
# Transpose linear matrix
if len(weight.shape) >= 2 and "token_embd" not in name:
if name == "token_embd":
weight *= config.embedding_multiplier_scale
elif len(weight.shape) >= 2:
# Transpose linear matrix
weight = weight.transpose(-1, -2)
if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
weight = weight[experts] # gather.
weight = weight[config.experts] # gather.
return weight
def maybe_permute_tensor(name, tensor, config):
def permute(weights, n_head):
return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)
if name.endswith("attn_k"):
return permute(tensor, config.num_key_value_heads)
elif name.endswith("attn_q"):
return permute(tensor, config.num_attention_heads)
return tensor
def extract_vocabulary_from_model(vocab):
tokens = []
scores = []