diff --git a/convert_grok.py b/convert_grok.py index d71d647dd..7d0dbfc77 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -204,7 +204,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): # First operate on meta tensors to find shapes and dtypes for GGUF header. for idx, name in enumerate(weight_names): weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000") - meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta") + meta_tensor = convert_weight(name, weight, scales, config, device="meta") dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type) quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) f.add_tensor_info( @@ -226,8 +226,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): for name in weight_names: weight, scales = weights.pop(name) - tensor = convert_weight(name, weight, scales, config.experts) - tensor = maybe_permute_tensor(name, tensor, config) + tensor = convert_weight(name, weight, scales, config) _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() @@ -258,7 +257,7 @@ def from_numpy(array): return torch.from_numpy(array) -def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None): +def convert_weight(name, weight, scales, config, dtype=torch.float32, device=None): # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1 weight = from_numpy(weight).to(device=device, dtype=dtype) if scales is not None: @@ -271,32 +270,19 @@ def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=No else: weight = weight * scale - # Transpose linear matrix - if len(weight.shape) >= 2 and "token_embd" not in name: + if name == "token_embd": + weight *= config.embedding_multiplier_scale + elif len(weight.shape) >= 2: + # Transpose linear matrix weight = weight.transpose(-1, -2) + if name.endswith("ffn_gate_inp") or name.endswith("_exps"): - weight = weight[experts] # gather. + weight = weight[config.experts] # gather. return weight -def maybe_permute_tensor(name, tensor, config): - def permute(weights, n_head): - return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape) - ) - - if name.endswith("attn_k"): - return permute(tensor, config.num_key_value_heads) - elif name.endswith("attn_q"): - return permute(tensor, config.num_attention_heads) - - return tensor - - def extract_vocabulary_from_model(vocab): tokens = [] scores = []