convert_hf : simplify modify_tensors for InternLM2

* convert_lora : lazy conversion

* llama : load and use alpha from LoRA adapters
This commit is contained in:
Francis Couture-Harpin 2024-07-15 02:35:06 -04:00
parent 9d96328bdf
commit 8956543c09
4 changed files with 123 additions and 66 deletions

View file

@ -8,9 +8,10 @@ import logging
import argparse
import os
import sys
import json
from math import prod
from pathlib import Path
from types import EllipsisType
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
import torch
@ -22,7 +23,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
import gguf
# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import Model
from convert_hf_to_gguf import LazyTorchTensor, Model
logger = logging.getLogger("lora-to-gguf")
@ -35,37 +36,45 @@ class PartialLoraTensor:
# magic to support tensor shape modifications and splitting
class LoraTorchTensor:
_lora_A: Tensor
_lora_B: Tensor
_lora_A: Tensor # (n_rank, row_size)
_lora_B: Tensor # (col_size, n_rank)
_rank: int
def __init__(self, A: Tensor, B: Tensor):
assert len(A.shape) == len(B.shape)
assert A.shape[-2] == B.shape[-1]
if A.dtype != B.dtype:
A = A.to(torch.float32)
B = B.to(torch.float32)
self._lora_A = A
self._lora_B = B
assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
self._rank = self._lora_B.shape[-1]
self._rank = B.shape[-1]
def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
return (self._lora_A, self._lora_B)
def __getitem__(
self,
indices: (
SupportsIndex
| slice
| tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
| tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature
),
) -> LoraTorchTensor:
shape = self.shape
if isinstance(indices, (SupportsIndex, slice)):
if isinstance(indices, SupportsIndex):
if len(shape) > 2:
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
else:
raise NotImplementedError
raise NotImplementedError # can't return a vector
elif isinstance(indices, slice):
if len(shape) > 2:
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
else:
return LoraTorchTensor(self._lora_A, self._lora_B[indices])
elif isinstance(indices, tuple):
assert len(indices) > 0
if isinstance(indices[-1], EllipsisType):
if indices[-1] is Ellipsis:
return self[indices[:-1]]
# expand ellipsis
indices = tuple(
@ -73,7 +82,7 @@ class LoraTorchTensor:
for v in (
(
(slice(None, None) for _ in range(len(indices) - 1))
if isinstance(i, EllipsisType)
if i is Ellipsis
else (i,)
)
for i in indices
@ -85,11 +94,14 @@ class LoraTorchTensor:
indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
# TODO: make sure this is correct
# lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
indices_A = (
*(
0 if isinstance(i, SupportsIndex) else slice(None, None)
for i in indices[:-2]
(
j.__index__() % self._lora_A.shape[i]
if isinstance(j, SupportsIndex)
else slice(None, None)
)
for i, j in enumerate(indices[:-2])
),
slice(None, None),
indices[-1],
@ -97,7 +109,7 @@ class LoraTorchTensor:
indices_B = indices[:-1]
return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
else:
raise NotImplementedError
raise NotImplementedError # unknown indice type
@property
def dtype(self) -> torch.dtype:
@ -106,23 +118,37 @@ class LoraTorchTensor:
@property
def shape(self) -> tuple[int, ...]:
assert len(self._lora_A.shape) == len(self._lora_B.shape)
return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
def size(self, dim=None):
assert dim is None
return self.shape
def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
if isinstance(shape[0], tuple):
new_shape: tuple[int] = shape[0]
new_shape: tuple[int, ...] = shape[0]
else:
new_shape = cast(tuple[int], shape)
new_shape = cast(tuple[int, ...], shape)
orig_shape = self.shape
if len(new_shape) < 2:
raise NotImplementedError # can't become a vector
# expand -1 in the shape
if any(dim == -1 for dim in new_shape):
n_elems = prod(orig_shape)
n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
assert n_elems % n_new_elems == 0
new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
if new_shape[-1] != orig_shape[-1]:
raise NotImplementedError
raise NotImplementedError # can't reshape the row size trivially
shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
shape_B = (*new_shape[:-1], self._rank)
return LoraTorchTensor(
self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
self._lora_B.reshape((*new_shape[:-1], self._rank)),
self._lora_A.reshape(shape_A),
self._lora_B.reshape(shape_B),
)
def reshape_as(self, other: Tensor) -> LoraTorchTensor:
@ -134,12 +160,15 @@ class LoraTorchTensor:
def permute(self, *dims: int) -> LoraTorchTensor:
shape = self.shape
dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
if dims[-1] == -2 and dims[-2] == -1:
return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
else:
assert dims[-1] == -1
if dims[-1] == -1:
# TODO: support higher dimensional A shapes bigger than 1
assert all(dim == 1 for dim in self._lora_A.shape[:-2])
return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
else:
# TODO: compose the above two
raise NotImplementedError
def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
shape = self.shape
@ -181,11 +210,13 @@ class LoraTorchTensor:
torch.cat([a._lora_A for a in args[0]], dim),
torch.cat([b._lora_B for b in args[0]], dim),
)
else:
elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
return LoraTorchTensor(
args[0][0]._lora_A, # TODO: is this correct? (can't cat over the rank)
args[0][0]._lora_A,
torch.cat([b._lora_B for b in args[0]], dim),
)
else:
raise NotImplementedError
else:
raise NotImplementedError
@ -205,13 +236,17 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
help="model is executed on big endian machine",
)
parser.add_argument(
"--no-lazy", action="store_true",
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
)
parser.add_argument(
"--verbose", action="store_true",
help="increase output verbosity",
@ -237,13 +272,16 @@ if __name__ == '__main__':
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"auto": gguf.LlamaFileType.GUESSED,
}
ftype = ftype_map[args.outtype]
dir_base_model = args.base
dir_lora = args.lora_path
input_json = os.path.join(dir_lora, "adapter_config.json")
input_model = os.path.join(dir_lora, "adapter_model.safetensors")
dir_base_model: Path = args.base
dir_lora: Path = args.lora_path
lora_config = dir_lora / "adapter_config.json"
input_model = dir_lora / "adapter_model.safetensors"
if args.outfile is not None:
fname_out = args.outfile
else:
@ -276,6 +314,8 @@ if __name__ == '__main__':
tensor_map: dict[str, PartialLoraTensor] = {}
for name, tensor in lora_model.items():
if self.lazy:
tensor = LazyTorchTensor.from_eager(tensor)
base_name = get_base_tensor_name(name)
is_lora_a = ".lora_A.weight" in name
is_lora_b = ".lora_B.weight" in name
@ -305,16 +345,30 @@ if __name__ == '__main__':
dest = super().modify_tensors(data_torch, name, bid)
for dest_name, dest_data in dest:
assert isinstance(dest_data, LoraTorchTensor)
# logger.info(f"{orig_name} --> {dest_name}")
yield (dest_name + ".lora_a", dest_data._lora_A)
yield (dest_name + ".lora_b", dest_data._lora_B)
lora_a, lora_b = dest_data.get_lora_A_B()
model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
yield (dest_name + ".lora_a", lora_a)
yield (dest_name + ".lora_b", lora_b)
model_instance = LoraModel(
dir_base_model,
ftype,
fname_out,
is_big_endian=args.bigendian,
use_temp_file=False,
eager=args.no_lazy,
model_name=None,
)
logger.info("Set model parameters")
model_instance.set_gguf_parameters()
# adapter_config = json.load(input_json)
with open(lora_config, "r") as f:
lparams: dict[str, Any] = json.load(f)
alpha = lparams["lora_alpha"]
model_instance.gguf_writer.add_string("training.type", "finetune_lora")
model_instance.gguf_writer.add_float32("training.lora.alpha", float(alpha))
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
logger.info("Exporting model...")