diff --git a/convert_grok.py b/convert_grok.py index 714a216fc..05a304407 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -197,12 +197,20 @@ def get_dtype_and_ggml_type(name, tensor, ggml_type): def dump_state_dict(f, ggml_type, input_dir, config): - weight_names = get_weight_names(config.num_hidden_layers) weights = {} - # First operate on meta tensors to find shapes and dtypes for GGUF header. - for idx, name in enumerate(weight_names): - weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000") + # Load weights in file order (mmap'ed). + for idx, name in enumerate(get_weight_names(config.num_hidden_layers)): + weights[name] = get_weights(f"{input_dir}/tensor{idx:05}_000") + + logging.debug("Loaded %i files", len(weights)) + + # But write in layer order. + weight_names = get_weight_names(config.num_hidden_layers, lexicographic=False) + + # Operate on meta tensors to find shapes and dtypes for GGUF header. + for name in weight_names: + weight, scales = weights[name] meta_tensor = convert_weight(name, weight, scales, config, device="meta") dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type) quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) @@ -213,8 +221,6 @@ def dump_state_dict(f, ggml_type, input_dir, config): quantized_meta_tensor.nbytes, tensor_ggml_type, ) - weights[name] = weight, scales - logging.debug("Loaded %i files", len(weight_names)) f.write_header_to_file() f.write_kv_data_to_file() @@ -244,7 +250,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): except NameError: pass - if len(tensor_info) != len(weight_names): + if weights: logging.warning("Not all tensors are converted") @@ -293,8 +299,10 @@ def extract_vocabulary_from_model(vocab): return tokens, scores, toktypes -def get_weight_names(num_hidden_layers=64): - """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files.""" +def get_weight_names(num_hidden_layers=64, lexicographic=True): + """Return Grok-1 weight names. + + If `lexicographic` is set, the order is as in the tensor#####_000 files.""" weight_names = [ gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD], @@ -317,7 +325,10 @@ def get_weight_names(num_hidden_layers=64): ) layers = [str(bid) for bid in range(64)] - layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ... + + if lexicographic: + # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ... + layers.sort() for bid in layers[:num_hidden_layers]: for key in layer: