Write tensors in layer order.

This commit is contained in:
Heiner 2024-05-23 15:07:27 +02:00
parent 60b29ea6e4
commit 0a1ef1127f

View file

@ -197,12 +197,20 @@ def get_dtype_and_ggml_type(name, tensor, ggml_type):
def dump_state_dict(f, ggml_type, input_dir, config):
weight_names = get_weight_names(config.num_hidden_layers)
weights = {}
# First operate on meta tensors to find shapes and dtypes for GGUF header.
for idx, name in enumerate(weight_names):
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
# Load weights in file order (mmap'ed).
for idx, name in enumerate(get_weight_names(config.num_hidden_layers)):
weights[name] = get_weights(f"{input_dir}/tensor{idx:05}_000")
logging.debug("Loaded %i files", len(weights))
# But write in layer order.
weight_names = get_weight_names(config.num_hidden_layers, lexicographic=False)
# Operate on meta tensors to find shapes and dtypes for GGUF header.
for name in weight_names:
weight, scales = weights[name]
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type)
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
@ -213,8 +221,6 @@ def dump_state_dict(f, ggml_type, input_dir, config):
quantized_meta_tensor.nbytes,
tensor_ggml_type,
)
weights[name] = weight, scales
logging.debug("Loaded %i files", len(weight_names))
f.write_header_to_file()
f.write_kv_data_to_file()
@ -244,7 +250,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
except NameError:
pass
if len(tensor_info) != len(weight_names):
if weights:
logging.warning("Not all tensors are converted")
@ -293,8 +299,10 @@ def extract_vocabulary_from_model(vocab):
return tokens, scores, toktypes
def get_weight_names(num_hidden_layers=64):
"""Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
def get_weight_names(num_hidden_layers=64, lexicographic=True):
"""Return Grok-1 weight names.
If `lexicographic` is set, the order is as in the tensor#####_000 files."""
weight_names = [
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
@ -317,7 +325,10 @@ def get_weight_names(num_hidden_layers=64):
)
layers = [str(bid) for bid in range(64)]
layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
if lexicographic:
# Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
layers.sort()
for bid in layers[:num_hidden_layers]:
for key in layer: