lora : add support for non-llama models (#3333)
* lora : add support for non-llama models ggml-ci * avoid leaking ggml_context on failure cleanup ggml-ci * lora : allow 1d tensors * lora : include embd and output layers in size calculation * fix style
This commit is contained in:
parent
8a5be3bd58
commit
c6c4fc081c
3 changed files with 114 additions and 106 deletions
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import sys
|
||||
from typing import Any, BinaryIO, Sequence
|
||||
|
@ -11,43 +10,15 @@ from typing import Any, BinaryIO, Sequence
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
if 'NO_LOCAL_GGUF' not in os.environ:
|
||||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
||||
import gguf
|
||||
|
||||
|
||||
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
|
||||
|
||||
|
||||
HF_SUBLAYER_TO_GGML = {
|
||||
"self_attn.q_proj": "attn_q",
|
||||
"self_attn.k_proj": "attn_k",
|
||||
"self_attn.v_proj": "attn_v",
|
||||
"self_attn.o_proj": "attn_output",
|
||||
"mlp.gate_proj": "ffn_gate",
|
||||
"mlp.down_proj": "ffn_down",
|
||||
"mlp.up_proj": "ffn_up",
|
||||
"input_layernorm": "attn_norm",
|
||||
"post_attention_layernorm": "ffn_norm",
|
||||
}
|
||||
|
||||
|
||||
def translate_tensor_name(t: str) -> str:
|
||||
match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
|
||||
if match:
|
||||
nn = match.group(1)
|
||||
sub_layer = match.group(2)
|
||||
lora_type = match.group(3)
|
||||
|
||||
sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
|
||||
if sub_layer_renamed is None:
|
||||
print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
|
||||
sys.exit(1)
|
||||
|
||||
output_string = (
|
||||
f"blk.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
|
||||
)
|
||||
return output_string
|
||||
else:
|
||||
print(f"Error: unrecognized tensor {t}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
||||
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
||||
fout.write(struct.pack("i", 1)) # file version
|
||||
|
@ -61,9 +32,7 @@ def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
|||
fout.write(struct.pack("i", int(params["lora_alpha"])))
|
||||
|
||||
|
||||
def write_tensor_header(
|
||||
self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
|
||||
) -> None:
|
||||
def write_tensor_header(fout: BinaryIO, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
|
||||
sname = name.encode("utf-8")
|
||||
fout.write(
|
||||
struct.pack(
|
||||
|
@ -78,11 +47,12 @@ def write_tensor_header(
|
|||
fout.seek((fout.tell() + 31) & -32)
|
||||
|
||||
|
||||
if len(sys.argv) != 2:
|
||||
print(f"Usage: python {sys.argv[0]} <path>")
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: python {sys.argv[0]} <path> [arch]")
|
||||
print(
|
||||
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
|
||||
)
|
||||
print(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
|
||||
sys.exit(1)
|
||||
|
||||
input_json = os.path.join(sys.argv[1], "adapter_config.json")
|
||||
|
@ -90,6 +60,14 @@ input_model = os.path.join(sys.argv[1], "adapter_model.bin")
|
|||
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
|
||||
|
||||
model = torch.load(input_model, map_location="cpu")
|
||||
arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"
|
||||
|
||||
if arch_name not in gguf.MODEL_ARCH_NAMES.values():
|
||||
print(f"Error: unsupported architecture {arch_name}")
|
||||
sys.exit(1)
|
||||
|
||||
arch = list(gguf.MODEL_ARCH_NAMES.keys())[list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)]
|
||||
name_map = gguf.TensorNameMap(arch, 200) # 200 layers ought to be enough for anyone
|
||||
|
||||
with open(input_json, "r") as f:
|
||||
params = json.load(f)
|
||||
|
@ -117,6 +95,7 @@ with open(output_path, "wb") as fout:
|
|||
|
||||
write_file_header(fout, params)
|
||||
for k, v in model.items():
|
||||
orig_k = k
|
||||
if k.endswith(".default.weight"):
|
||||
k = k.replace(".default.weight", ".weight")
|
||||
if k in ["llama_proj.weight", "llama_proj.bias"]:
|
||||
|
@ -129,7 +108,32 @@ with open(output_path, "wb") as fout:
|
|||
v = v.float()
|
||||
|
||||
t = v.detach().numpy()
|
||||
tname = translate_tensor_name(k)
|
||||
|
||||
prefix = "base_model.model."
|
||||
if k.startswith(prefix):
|
||||
k = k[len(prefix) :]
|
||||
|
||||
lora_suffixes = (".lora_A.weight", ".lora_B.weight")
|
||||
if k.endswith(lora_suffixes):
|
||||
suffix = k[-len(lora_suffixes[0]):]
|
||||
k = k[: -len(lora_suffixes[0])]
|
||||
else:
|
||||
print(f"Error: unrecognized tensor name {orig_k}")
|
||||
sys.exit(1)
|
||||
|
||||
tname = name_map.get_name(k)
|
||||
if tname is None:
|
||||
print(f"Error: could not map tensor name {orig_k}")
|
||||
print(" Note: the arch parameter must be specified if the model is not llama")
|
||||
sys.exit(1)
|
||||
|
||||
if suffix == ".lora_A.weight":
|
||||
tname += ".weight.loraA"
|
||||
elif suffix == ".lora_B.weight":
|
||||
tname += ".weight.loraB"
|
||||
else:
|
||||
assert False
|
||||
|
||||
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
||||
write_tensor_header(fout, tname, t.shape, t.dtype)
|
||||
t.tofile(fout)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue