Update exporter and support scaling
This commit is contained in:
parent
dc65707130
commit
87c518bb3d
2 changed files with 92 additions and 31 deletions
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
|
@ -14,8 +15,10 @@ import torch
|
|||
class UnquantizedDataType:
|
||||
name: str
|
||||
|
||||
DT_F16 = UnquantizedDataType('F16')
|
||||
DT_F32 = UnquantizedDataType('F32')
|
||||
|
||||
DT_F16 = UnquantizedDataType("F16")
|
||||
DT_F32 = UnquantizedDataType("F32")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QuantizedDataType:
|
||||
|
@ -23,6 +26,7 @@ class QuantizedDataType:
|
|||
have_addends: bool
|
||||
have_g_idx: bool
|
||||
|
||||
|
||||
DataType = UnquantizedDataType
|
||||
|
||||
DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
|
||||
|
@ -35,17 +39,28 @@ DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
|
|||
DT_F32: np.dtype(np.float32),
|
||||
}
|
||||
|
||||
NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
|
||||
NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {
|
||||
dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()
|
||||
}
|
||||
|
||||
HF_SUBLAYER_TO_GGML = {
|
||||
"self_attn.q_proj": "attention.wq.weight",
|
||||
"self_attn.k_proj": "attention.wk.weight",
|
||||
"self_attn.v_proj": "attention.wv.weight",
|
||||
"self_attn.o_proj": "attention.wo.weight",
|
||||
# "embed_tokens.weight": "tok_embeddings.weight",
|
||||
# "norm.weight": "norm.weight",
|
||||
# "lm_head.weight": "output.weight",
|
||||
# "mlp.gate_proj": "feed_forward.w1.weight",
|
||||
# "mlp.down_proj": "feed_forward.w2.weight",
|
||||
# "mlp.up_proj": "feed_forward.w3.weight",
|
||||
# "input_layernorm": "attention_norm.weight",
|
||||
# "post_attention_layernorm": "ffn_norm.weight",
|
||||
}
|
||||
|
||||
|
||||
def translate_tensor_name(t):
|
||||
match = re.match(r'.*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight', t)
|
||||
match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
|
||||
if match:
|
||||
nn = match.group(1)
|
||||
sub_layer = match.group(2)
|
||||
|
@ -54,50 +69,85 @@ def translate_tensor_name(t):
|
|||
sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
|
||||
if sub_layer_renamed is None:
|
||||
print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
output_string = f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.lora{lora_type}"
|
||||
return output_string
|
||||
else:
|
||||
print(f"Error: unrecognized tensor {t}")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
def write_file_header(fout):
|
||||
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
||||
fout.write(struct.pack("i", 1)) # file version
|
||||
|
||||
def write_file_header(fout, params):
|
||||
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
||||
fout.write(struct.pack("i", 1)) # file version
|
||||
fout.write(struct.pack("ii", params["r"], params["lora_alpha"]))
|
||||
|
||||
|
||||
def write_tensor_header(self, name: str, shape: Sequence[int], data_type: 1) -> None:
|
||||
sname = name.encode('utf-8')
|
||||
fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]]))
|
||||
sname = name.encode("utf-8")
|
||||
fout.write(
|
||||
struct.pack(
|
||||
"iii",
|
||||
len(shape),
|
||||
len(sname),
|
||||
DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]],
|
||||
)
|
||||
)
|
||||
fout.write(struct.pack("i" * len(shape), *shape[::-1]))
|
||||
fout.write(sname)
|
||||
fout.seek((fout.tell() + 31) & -32)
|
||||
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: python {sys.argv[0]} adapter_model.bin [ggml_adapter_model.bin]")
|
||||
|
||||
if len(sys.argv) != 2:
|
||||
print(f"Usage: python {sys.argv[0]} <path>")
|
||||
print(
|
||||
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
input_path = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_path = sys.argv[2]
|
||||
else:
|
||||
output_filename = f"ggml_{os.path.basename(input_path)}"
|
||||
output_path = os.path.join(os.path.dirname(input_path), output_filename)
|
||||
input_json = os.path.join(sys.argv[1], "adapter_config.json")
|
||||
input_model = os.path.join(sys.argv[1], "adapter_model.bin")
|
||||
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
|
||||
|
||||
model = torch.load(input_path, map_location="cpu")
|
||||
model = torch.load(input_model, map_location="cpu")
|
||||
|
||||
with open(input_json, "r") as f:
|
||||
params = json.load(f)
|
||||
|
||||
if params["peft_type"] != "LORA":
|
||||
print(f"Error: unsupported adapter type {params['peft_type']}, expected LORA")
|
||||
sys.exit(1)
|
||||
|
||||
if params["fan_in_fan_out"] == True:
|
||||
print("Error: param fan_in_fan_out is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
if params["bias"] is not None and params["bias"] != "none":
|
||||
print("Error: param bias is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
# TODO: these seem to be layers that have been trained but without lora.
|
||||
# doesn't seem widely used but eventually should be supported
|
||||
if params["modules_to_save"] is not None and len(params["modules_to_save"]) > 0:
|
||||
print("Error: param modules_to_save is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
with open(output_path, "wb") as fout:
|
||||
write_file_header(fout)
|
||||
fout.truncate()
|
||||
|
||||
write_file_header(fout, params)
|
||||
for k, v in model.items():
|
||||
# since ggml doesn't always support other types for the second operand,
|
||||
# the tensors are always converted and exported as f32
|
||||
t = v.float().numpy()
|
||||
v = v.float()
|
||||
t = v.numpy()
|
||||
if "lora_A" in k:
|
||||
t = t.T
|
||||
print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
||||
print(
|
||||
f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB"
|
||||
)
|
||||
write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype)
|
||||
t.tofile(fout)
|
||||
|
||||
print(f"Converted {input_path} to {output_path}")
|
||||
print(f"Converted {input_json} and {input_model} to {output_path}")
|
||||
|
|
23
llama.cpp
23
llama.cpp
|
@ -1789,6 +1789,15 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
|
|||
}
|
||||
}
|
||||
|
||||
int32_t lora_r;
|
||||
int32_t lora_alpha;
|
||||
fin.read((char *) &lora_r, sizeof(lora_r));
|
||||
fin.read((char *) &lora_alpha, sizeof(lora_alpha));
|
||||
float scaling = (float)lora_alpha / (float)lora_r;
|
||||
|
||||
fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
||||
|
||||
|
||||
// create a temporary ggml context to store the lora tensors
|
||||
std::vector<uint8_t> buf(1024 * 1024 * 100);
|
||||
struct ggml_init_params params;
|
||||
|
@ -1890,11 +1899,13 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
|
|||
// w = w + BA*s
|
||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
|
||||
|
||||
//if (true) {
|
||||
// ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, 1.0f);
|
||||
// BA = ggml_scale(lora_ctx, BA, scale_tensor);
|
||||
//}
|
||||
ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);
|
||||
if (scaling != 1.0f) {
|
||||
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
|
||||
BA = ggml_scale(lora_ctx, BA, scale_tensor);
|
||||
}
|
||||
|
||||
ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
|
||||
//ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);
|
||||
//r = ggml_cpy(lora_ctx, r, tensor);
|
||||
|
||||
struct ggml_cgraph gf = ggml_build_forward(r);
|
||||
|
@ -1902,7 +1913,7 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
|
|||
ggml_graph_compute(lora_ctx, &gf);
|
||||
|
||||
// hack until ggml_cpy supports quantized tensors
|
||||
memcpy(tensor->data, r->data, ggml_nbytes(tensor));
|
||||
// memcpy(tensor->data, r->data, ggml_nbytes(tensor));
|
||||
|
||||
// we won't need these tensors again, reset the context to save memory
|
||||
ggml_free(lora_ctx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue