Address review comments by foldl.
This commit is contained in:
parent
d894497a96
commit
ef671c693d
1 changed files with 9 additions and 23 deletions
|
@ -204,7 +204,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
|
||||||
# First operate on meta tensors to find shapes and dtypes for GGUF header.
|
# First operate on meta tensors to find shapes and dtypes for GGUF header.
|
||||||
for idx, name in enumerate(weight_names):
|
for idx, name in enumerate(weight_names):
|
||||||
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
|
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
|
||||||
meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
|
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
|
||||||
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
|
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
|
||||||
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
|
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
|
||||||
f.add_tensor_info(
|
f.add_tensor_info(
|
||||||
|
@ -226,8 +226,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
|
||||||
|
|
||||||
for name in weight_names:
|
for name in weight_names:
|
||||||
weight, scales = weights.pop(name)
|
weight, scales = weights.pop(name)
|
||||||
tensor = convert_weight(name, weight, scales, config.experts)
|
tensor = convert_weight(name, weight, scales, config)
|
||||||
tensor = maybe_permute_tensor(name, tensor, config)
|
|
||||||
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
|
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
|
||||||
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
|
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
|
||||||
|
|
||||||
|
@ -258,7 +257,7 @@ def from_numpy(array):
|
||||||
return torch.from_numpy(array)
|
return torch.from_numpy(array)
|
||||||
|
|
||||||
|
|
||||||
def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
|
def convert_weight(name, weight, scales, config, dtype=torch.float32, device=None):
|
||||||
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
|
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
|
||||||
weight = from_numpy(weight).to(device=device, dtype=dtype)
|
weight = from_numpy(weight).to(device=device, dtype=dtype)
|
||||||
if scales is not None:
|
if scales is not None:
|
||||||
|
@ -271,32 +270,19 @@ def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=No
|
||||||
else:
|
else:
|
||||||
weight = weight * scale
|
weight = weight * scale
|
||||||
|
|
||||||
|
if name == "token_embd":
|
||||||
|
weight *= config.embedding_multiplier_scale
|
||||||
|
elif len(weight.shape) >= 2:
|
||||||
# Transpose linear matrix
|
# Transpose linear matrix
|
||||||
if len(weight.shape) >= 2 and "token_embd" not in name:
|
|
||||||
weight = weight.transpose(-1, -2)
|
weight = weight.transpose(-1, -2)
|
||||||
|
|
||||||
|
|
||||||
if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
|
if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
|
||||||
weight = weight[experts] # gather.
|
weight = weight[config.experts] # gather.
|
||||||
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
||||||
def maybe_permute_tensor(name, tensor, config):
|
|
||||||
def permute(weights, n_head):
|
|
||||||
return (
|
|
||||||
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
|
||||||
.swapaxes(1, 2)
|
|
||||||
.reshape(weights.shape)
|
|
||||||
)
|
|
||||||
|
|
||||||
if name.endswith("attn_k"):
|
|
||||||
return permute(tensor, config.num_key_value_heads)
|
|
||||||
elif name.endswith("attn_q"):
|
|
||||||
return permute(tensor, config.num_attention_heads)
|
|
||||||
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def extract_vocabulary_from_model(vocab):
|
def extract_vocabulary_from_model(vocab):
|
||||||
tokens = []
|
tokens = []
|
||||||
scores = []
|
scores = []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue