Fix layer order.
This commit is contained in:
parent
9a0629d545
commit
f177b6596c
1 changed files with 13 additions and 12 deletions
|
@ -126,7 +126,7 @@ def get_weights(fn):
|
|||
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
|
||||
# equivalent to ggml_quantize_q8_0 in ggml.c
|
||||
assert tensor.shape[1] % GGML_QK8_0 == 0
|
||||
tensor = tensor.view(-1, GGML_QK8_0)
|
||||
tensor = tensor.reshape(-1, GGML_QK8_0)
|
||||
scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
|
||||
tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
|
||||
# add scale into each block
|
||||
|
@ -152,7 +152,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
|
|||
def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
|
||||
# equivalent to ggml_quantize_q4_1 in ggml.c
|
||||
assert tensor.shape[1] % GGML_QK4_1 == 0
|
||||
tensor = tensor.view(-1, GGML_QK4_1)
|
||||
tensor = tensor.reshape(-1, GGML_QK4_1)
|
||||
abs_max_indices = tensor.max(dim=-1, keepdim=True).indices
|
||||
max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
|
||||
abs_min_indices = tensor.min(dim=-1, keepdim=True).indices
|
||||
|
@ -185,15 +185,13 @@ def maybe_quantize_tensor(tensor, ggml_type):
|
|||
raise NotImplementedError(f"Cannot quantize tensor of dtype {tensor.dtype} ({ggml_type})")
|
||||
|
||||
|
||||
def get_dtype_and_ggml_type(tensor, ggml_type):
|
||||
if tensor.ndim in (2, 3):
|
||||
def get_dtype_and_ggml_type(name, tensor, ggml_type):
|
||||
if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name:
|
||||
if tensor.shape[1] % GGML_QK8_0 == 0:
|
||||
return np.int8, ggml_type
|
||||
else:
|
||||
return np.float16, gguf.GGMLQuantizationType.F16
|
||||
else:
|
||||
# 1d weight: convert it to float32
|
||||
assert tensor.ndim == 1, tensor
|
||||
return np.float32, gguf.GGMLQuantizationType.F32
|
||||
|
||||
|
||||
|
@ -205,7 +203,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
|
|||
for idx, name in enumerate(weight_names):
|
||||
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
|
||||
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
|
||||
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
|
||||
dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type)
|
||||
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
|
||||
f.add_tensor_info(
|
||||
f"{name}.weight",
|
||||
|
@ -227,7 +225,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
|
|||
for name in weight_names:
|
||||
weight, scales = weights.pop(name)
|
||||
tensor = convert_weight(name, weight, scales, config)
|
||||
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
|
||||
_, tensor_ggml_type = get_dtype_and_ggml_type(name, tensor, ggml_type)
|
||||
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
|
||||
|
||||
logging.info(
|
||||
|
@ -317,7 +315,10 @@ def get_weight_names(num_hidden_layers=64):
|
|||
gguf.MODEL_TENSOR.FFN_GATE_INP,
|
||||
)
|
||||
|
||||
for bid in range(num_hidden_layers):
|
||||
layers = [str(bid) for bid in range(64)]
|
||||
layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
|
||||
|
||||
for bid in layers[:num_hidden_layers]:
|
||||
for key in layer:
|
||||
weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))
|
||||
|
||||
|
@ -333,7 +334,6 @@ def convert_grok(args, vocab, ggml_type):
|
|||
return _ffn_size
|
||||
|
||||
config = {
|
||||
"vocab_size": 128 * 1024,
|
||||
"hidden_act": "gelu",
|
||||
"pad_token_id": 0,
|
||||
"eos_token_id": 2,
|
||||
|
@ -366,8 +366,7 @@ def convert_grok(args, vocab, ggml_type):
|
|||
|
||||
f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
|
||||
|
||||
f.add_name("grok")
|
||||
f.add_vocab_size(config.vocab_size)
|
||||
f.add_name("grok-1")
|
||||
f.add_context_length(config.max_position_embeddings)
|
||||
f.add_embedding_length(config.hidden_size)
|
||||
f.add_block_count(config.num_hidden_layers)
|
||||
|
@ -389,6 +388,8 @@ def convert_grok(args, vocab, ggml_type):
|
|||
f.add_token_scores(scores)
|
||||
f.add_token_types(toktypes)
|
||||
|
||||
f.add_quantization_version(ggml_type)
|
||||
|
||||
dump_state_dict(f, ggml_type, args.input_dir, config)
|
||||
f.close()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue