Refactor lora adapter support (#8332)
* lora: load to devide buft
* add patch tensor function
* correct tensor patch
* llama_lora_adapter_apply
* correct ggml_backend_tensor_copy
* add llm_build_mm
* fix auto merge
* update based on review comments
* add convert script
* no more transpose A
* add f16 convert
* add metadata check
* add sanity check
* fix ftype
* add requirements
* fix requirements
* fix outfile
* conversion: only allow selected models
* fix types
* cuda : do not use dmmv if the tensor does not have enough cols
* llama : lora fixes
* do not disable mmap with lora
Co-authored-by: slaren <slarengh@gmail.com>
* llm_build_lora_mm_id
* convert_lora : MoE LoRA conversion support
* convert_lora : prefer safetensors, similarly to convert_hf
* convert_hf : simplify modify_tensors for InternLM2
* convert_lora : lazy conversion
* llama : load and use alpha from LoRA adapters
* llama : use llm_build_lora_mm in most model graphs
* auto scale
* Revert "auto scale"
This reverts commit 42415a4874.
* remove redundant params
* Apply suggestions from code review
Co-authored-by: slaren <slarengh@gmail.com>
* change kv metadata
* move add_type to __init__
* convert_hf : move add_type to main()
* convert_lora : use the GGUFWriter from Model instead of overwriting it
---------
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Francis Couture-Harpin <git@compilade.net>
			
			
This commit is contained in:
		
							parent
							
								
									4db8f60fe7
								
							
						
					
					
						commit
						97bdd26eee
					
				
					 12 changed files with 963 additions and 530 deletions
				
			
		|  | @ -685,7 +685,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa | ||||||
|     if (arg == "--lora") { |     if (arg == "--lora") { | ||||||
|         CHECK_ARG |         CHECK_ARG | ||||||
|         params.lora_adapter.emplace_back(argv[i], 1.0f); |         params.lora_adapter.emplace_back(argv[i], 1.0f); | ||||||
|         params.use_mmap = false; |  | ||||||
|         return true; |         return true; | ||||||
|     } |     } | ||||||
|     if (arg == "--lora-scaled") { |     if (arg == "--lora-scaled") { | ||||||
|  | @ -693,7 +692,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa | ||||||
|         const char* lora_adapter = argv[i]; |         const char* lora_adapter = argv[i]; | ||||||
|         CHECK_ARG |         CHECK_ARG | ||||||
|         params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); |         params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); | ||||||
|         params.use_mmap = false; |  | ||||||
|         return true; |         return true; | ||||||
|     } |     } | ||||||
|     if (arg == "--lora-base") { |     if (arg == "--lora-base") { | ||||||
|  | @ -2089,19 +2087,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par | ||||||
|     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { |     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { | ||||||
|         const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); |         const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); | ||||||
|         float lora_scale = std::get<1>(params.lora_adapter[i]); |         float lora_scale = std::get<1>(params.lora_adapter[i]); | ||||||
|         int err = llama_model_apply_lora_from_file(model, |         auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str()); | ||||||
|                                              lora_adapter.c_str(), |         if (adapter == nullptr) { | ||||||
|                                              lora_scale, |  | ||||||
|                                              ((i > 0) || params.lora_base.empty()) |  | ||||||
|                                                 ? NULL |  | ||||||
|                                                 : params.lora_base.c_str(), |  | ||||||
|                                              params.n_threads); |  | ||||||
|         if (err != 0) { |  | ||||||
|             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); |             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); | ||||||
|             llama_free(lctx); |             llama_free(lctx); | ||||||
|             llama_free_model(model); |             llama_free_model(model); | ||||||
|             return std::make_tuple(nullptr, nullptr); |             return std::make_tuple(nullptr, nullptr); | ||||||
|         } |         } | ||||||
|  |         llama_lora_adapter_set(lctx, adapter, lora_scale); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     if (params.ignore_eos) { |     if (params.ignore_eos) { | ||||||
|  |  | ||||||
|  | @ -2264,13 +2264,6 @@ class InternLM2Model(Model): | ||||||
| 
 | 
 | ||||||
|         special_vocab.add_to_gguf(self.gguf_writer) |         special_vocab.add_to_gguf(self.gguf_writer) | ||||||
| 
 | 
 | ||||||
|     def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int): |  | ||||||
|         if n_head_kv is not None and n_head != n_head_kv: |  | ||||||
|             n_head = n_head_kv |  | ||||||
|         return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) |  | ||||||
|                 .swapaxes(1, 2) |  | ||||||
|                 .reshape(weights.shape)) |  | ||||||
| 
 |  | ||||||
|     def set_gguf_parameters(self): |     def set_gguf_parameters(self): | ||||||
|         self.gguf_writer.add_name("InternLM2") |         self.gguf_writer.add_name("InternLM2") | ||||||
|         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) |         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) | ||||||
|  | @ -2290,26 +2283,22 @@ class InternLM2Model(Model): | ||||||
|     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: |     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: | ||||||
|         num_heads = self.hparams["num_attention_heads"] |         num_heads = self.hparams["num_attention_heads"] | ||||||
|         num_kv_heads = self.hparams["num_key_value_heads"] |         num_kv_heads = self.hparams["num_key_value_heads"] | ||||||
|         hidden_size = self.hparams["hidden_size"] |         n_embd = self.hparams["hidden_size"] | ||||||
|         q_per_kv = num_heads // num_kv_heads |         q_per_kv = num_heads // num_kv_heads | ||||||
|         head_dim = hidden_size // num_heads |         head_dim = n_embd // num_heads | ||||||
|         num_groups = num_heads // q_per_kv |         num_groups = num_heads // q_per_kv | ||||||
| 
 | 
 | ||||||
|         qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv" |         if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: | ||||||
| 
 |  | ||||||
|         if re.match(qkv_pattern, name): |  | ||||||
|             bid = re.findall(qkv_pattern, name)[0] |  | ||||||
|             qkv = data_torch |             qkv = data_torch | ||||||
|             # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim) | 
 | ||||||
|             qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim)) |             qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd)) | ||||||
|             q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :] |             q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1] | ||||||
|  | 
 | ||||||
|             # The model weights of q and k equire additional reshape. |             # The model weights of q and k equire additional reshape. | ||||||
|             # q = self._hf_permute_qk(rearrange(q, " o g n i ->  o (g n i)").T, num_heads, num_heads) |             q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads) | ||||||
|             q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads) |             k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads) | ||||||
|             # k = self._hf_permute_qk(rearrange(k, " o g n i ->  o (g n i)").T, num_heads, num_kv_heads) |             v = v.reshape((-1, v.shape[-1])) | ||||||
|             k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads) | 
 | ||||||
|             # v = rearrange(v, " o g n i ->  o (g n i)").T |  | ||||||
|             v = v.reshape((v.shape[0], -1)).T |  | ||||||
|             return [ |             return [ | ||||||
|                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), |                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), | ||||||
|                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), |                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), | ||||||
|  | @ -3585,6 +3574,7 @@ def main() -> None: | ||||||
|                                      small_first_shard=args.no_tensor_first_split) |                                      small_first_shard=args.no_tensor_first_split) | ||||||
| 
 | 
 | ||||||
|         logger.info("Set model parameters") |         logger.info("Set model parameters") | ||||||
|  |         model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL) | ||||||
|         model_instance.set_gguf_parameters() |         model_instance.set_gguf_parameters() | ||||||
| 
 | 
 | ||||||
|         logger.info("Set model tokenizer") |         logger.info("Set model tokenizer") | ||||||
|  |  | ||||||
							
								
								
									
										374
									
								
								convert_lora_to_gguf.py
									
										
									
									
									
										Executable file
									
								
							
							
						
						
									
										374
									
								
								convert_lora_to_gguf.py
									
										
									
									
									
										Executable file
									
								
							|  | @ -0,0 +1,374 @@ | ||||||
|  | #!/usr/bin/env python3 | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | from __future__ import annotations | ||||||
|  | 
 | ||||||
|  | from dataclasses import dataclass | ||||||
|  | import logging | ||||||
|  | import argparse | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | import json | ||||||
|  | from math import prod | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from torch import Tensor | ||||||
|  | 
 | ||||||
|  | if 'NO_LOCAL_GGUF' not in os.environ: | ||||||
|  |     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) | ||||||
|  | import gguf | ||||||
|  | 
 | ||||||
|  | # reuse model definitions from convert_hf_to_gguf.py | ||||||
|  | from convert_hf_to_gguf import LazyTorchTensor, Model | ||||||
|  | 
 | ||||||
|  | logger = logging.getLogger("lora-to-gguf") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class PartialLoraTensor: | ||||||
|  |     A: Tensor | None = None | ||||||
|  |     B: Tensor | None = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # magic to support tensor shape modifications and splitting | ||||||
|  | class LoraTorchTensor: | ||||||
|  |     _lora_A: Tensor  # (n_rank, row_size) | ||||||
|  |     _lora_B: Tensor  # (col_size, n_rank) | ||||||
|  |     _rank: int | ||||||
|  | 
 | ||||||
|  |     def __init__(self, A: Tensor, B: Tensor): | ||||||
|  |         assert len(A.shape) == len(B.shape) | ||||||
|  |         assert A.shape[-2] == B.shape[-1] | ||||||
|  |         if A.dtype != B.dtype: | ||||||
|  |             A = A.to(torch.float32) | ||||||
|  |             B = B.to(torch.float32) | ||||||
|  |         self._lora_A = A | ||||||
|  |         self._lora_B = B | ||||||
|  |         self._rank = B.shape[-1] | ||||||
|  | 
 | ||||||
|  |     def get_lora_A_B(self) -> tuple[Tensor, Tensor]: | ||||||
|  |         return (self._lora_A, self._lora_B) | ||||||
|  | 
 | ||||||
|  |     def __getitem__( | ||||||
|  |         self, | ||||||
|  |         indices: ( | ||||||
|  |             SupportsIndex | ||||||
|  |             | slice | ||||||
|  |             | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature | ||||||
|  |         ), | ||||||
|  |     ) -> LoraTorchTensor: | ||||||
|  |         shape = self.shape | ||||||
|  |         if isinstance(indices, SupportsIndex): | ||||||
|  |             if len(shape) > 2: | ||||||
|  |                 return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) | ||||||
|  |             else: | ||||||
|  |                 raise NotImplementedError  # can't return a vector | ||||||
|  |         elif isinstance(indices, slice): | ||||||
|  |             if len(shape) > 2: | ||||||
|  |                 return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) | ||||||
|  |             else: | ||||||
|  |                 return LoraTorchTensor(self._lora_A, self._lora_B[indices]) | ||||||
|  |         elif isinstance(indices, tuple): | ||||||
|  |             assert len(indices) > 0 | ||||||
|  |             if indices[-1] is Ellipsis: | ||||||
|  |                 return self[indices[:-1]] | ||||||
|  |             # expand ellipsis | ||||||
|  |             indices = tuple( | ||||||
|  |                 u | ||||||
|  |                 for v in ( | ||||||
|  |                     ( | ||||||
|  |                         (slice(None, None) for _ in range(len(indices) - 1)) | ||||||
|  |                         if i is Ellipsis | ||||||
|  |                         else (i,) | ||||||
|  |                     ) | ||||||
|  |                     for i in indices | ||||||
|  |                 ) | ||||||
|  |                 for u in v | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             if len(indices) < len(shape): | ||||||
|  |                 indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape)))) | ||||||
|  | 
 | ||||||
|  |             # TODO: make sure this is correct | ||||||
|  |             indices_A = ( | ||||||
|  |                 *( | ||||||
|  |                     ( | ||||||
|  |                         j.__index__() % self._lora_A.shape[i] | ||||||
|  |                         if isinstance(j, SupportsIndex) | ||||||
|  |                         else slice(None, None) | ||||||
|  |                     ) | ||||||
|  |                     for i, j in enumerate(indices[:-2]) | ||||||
|  |                 ), | ||||||
|  |                 slice(None, None), | ||||||
|  |                 indices[-1], | ||||||
|  |             ) | ||||||
|  |             indices_B = indices[:-1] | ||||||
|  |             return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B]) | ||||||
|  |         else: | ||||||
|  |             raise NotImplementedError  # unknown indice type | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def dtype(self) -> torch.dtype: | ||||||
|  |         assert self._lora_A.dtype == self._lora_B.dtype | ||||||
|  |         return self._lora_A.dtype | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def shape(self) -> tuple[int, ...]: | ||||||
|  |         assert len(self._lora_A.shape) == len(self._lora_B.shape) | ||||||
|  |         return (*self._lora_B.shape[:-1], self._lora_A.shape[-1]) | ||||||
|  | 
 | ||||||
|  |     def size(self, dim=None): | ||||||
|  |         assert dim is None | ||||||
|  |         return self.shape | ||||||
|  | 
 | ||||||
|  |     def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor: | ||||||
|  |         if isinstance(shape[0], tuple): | ||||||
|  |             new_shape: tuple[int, ...] = shape[0] | ||||||
|  |         else: | ||||||
|  |             new_shape = cast(tuple[int, ...], shape) | ||||||
|  |         orig_shape = self.shape | ||||||
|  |         if len(new_shape) < 2: | ||||||
|  |             raise NotImplementedError  # can't become a vector | ||||||
|  | 
 | ||||||
|  |         # expand -1 in the shape | ||||||
|  |         if any(dim == -1 for dim in new_shape): | ||||||
|  |             n_elems = prod(orig_shape) | ||||||
|  |             n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape) | ||||||
|  |             assert n_elems % n_new_elems == 0 | ||||||
|  |             new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),) | ||||||
|  | 
 | ||||||
|  |         if new_shape[-1] != orig_shape[-1]: | ||||||
|  |             raise NotImplementedError  # can't reshape the row size trivially | ||||||
|  | 
 | ||||||
|  |         shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1]) | ||||||
|  |         shape_B = (*new_shape[:-1], self._rank) | ||||||
|  |         return LoraTorchTensor( | ||||||
|  |             self._lora_A.reshape(shape_A), | ||||||
|  |             self._lora_B.reshape(shape_B), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def reshape_as(self, other: Tensor) -> LoraTorchTensor: | ||||||
|  |         return self.reshape(*other.shape) | ||||||
|  | 
 | ||||||
|  |     def view(self, *size: int) -> LoraTorchTensor: | ||||||
|  |         return self.reshape(*size) | ||||||
|  | 
 | ||||||
|  |     def permute(self, *dims: int) -> LoraTorchTensor: | ||||||
|  |         shape = self.shape | ||||||
|  |         dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims) | ||||||
|  |         if dims[-1] == -1: | ||||||
|  |             # TODO: support higher dimensional A shapes bigger than 1 | ||||||
|  |             assert all(dim == 1 for dim in self._lora_A.shape[:-2]) | ||||||
|  |             return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims)) | ||||||
|  |         if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1: | ||||||
|  |             return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) | ||||||
|  |         else: | ||||||
|  |             # TODO: compose the above two | ||||||
|  |             raise NotImplementedError | ||||||
|  | 
 | ||||||
|  |     def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: | ||||||
|  |         shape = self.shape | ||||||
|  |         dims = [i for i in range(len(shape))] | ||||||
|  |         dims[dim0], dims[dim1] = dims[dim1], dims[dim0] | ||||||
|  |         return self.permute(*dims) | ||||||
|  | 
 | ||||||
|  |     def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: | ||||||
|  |         return self.transpose(axis0, axis1) | ||||||
|  | 
 | ||||||
|  |     def to(self, *args, **kwargs): | ||||||
|  |         return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): | ||||||
|  |         del types  # unused | ||||||
|  | 
 | ||||||
|  |         if kwargs is None: | ||||||
|  |             kwargs = {} | ||||||
|  | 
 | ||||||
|  |         if func is torch.permute: | ||||||
|  |             return type(args[0]).permute(*args, **kwargs) | ||||||
|  |         elif func is torch.reshape: | ||||||
|  |             return type(args[0]).reshape(*args, **kwargs) | ||||||
|  |         elif func is torch.stack: | ||||||
|  |             assert isinstance(args[0], Sequence) | ||||||
|  |             dim = kwargs.get("dim", 0) | ||||||
|  |             assert dim == 0 | ||||||
|  |             return LoraTorchTensor( | ||||||
|  |                 torch.stack([a._lora_A for a in args[0]], dim), | ||||||
|  |                 torch.stack([b._lora_B for b in args[0]], dim), | ||||||
|  |             ) | ||||||
|  |         elif func is torch.cat: | ||||||
|  |             assert isinstance(args[0], Sequence) | ||||||
|  |             dim = kwargs.get("dim", 0) | ||||||
|  |             assert dim == 0 | ||||||
|  |             if len(args[0][0].shape) > 2: | ||||||
|  |                 return LoraTorchTensor( | ||||||
|  |                     torch.cat([a._lora_A for a in args[0]], dim), | ||||||
|  |                     torch.cat([b._lora_B for b in args[0]], dim), | ||||||
|  |                 ) | ||||||
|  |             elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]): | ||||||
|  |                 return LoraTorchTensor( | ||||||
|  |                     args[0][0]._lora_A, | ||||||
|  |                     torch.cat([b._lora_B for b in args[0]], dim), | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 raise NotImplementedError | ||||||
|  |         else: | ||||||
|  |             raise NotImplementedError | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_base_tensor_name(lora_tensor_name: str) -> str: | ||||||
|  |     base_name = lora_tensor_name.replace("base_model.model.", "") | ||||||
|  |     base_name = base_name.replace(".lora_A.weight", ".weight") | ||||||
|  |     base_name = base_name.replace(".lora_B.weight", ".weight") | ||||||
|  |     return base_name | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def parse_args() -> argparse.Namespace: | ||||||
|  |     parser = argparse.ArgumentParser( | ||||||
|  |         description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--outfile", type=Path, | ||||||
|  |         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", | ||||||
|  |         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--bigendian", action="store_true", | ||||||
|  |         help="model is executed on big endian machine", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--no-lazy", action="store_true", | ||||||
|  |         help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--verbose", action="store_true", | ||||||
|  |         help="increase output verbosity", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--base", type=Path, required=True, | ||||||
|  |         help="directory containing base model file", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "lora_path", type=Path, | ||||||
|  |         help="directory containing LoRA adapter file", | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     return parser.parse_args() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     args = parse_args() | ||||||
|  |     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) | ||||||
|  | 
 | ||||||
|  |     ftype_map: dict[str, gguf.LlamaFileType] = { | ||||||
|  |         "f32": gguf.LlamaFileType.ALL_F32, | ||||||
|  |         "f16": gguf.LlamaFileType.MOSTLY_F16, | ||||||
|  |         "bf16": gguf.LlamaFileType.MOSTLY_BF16, | ||||||
|  |         "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, | ||||||
|  |         "auto": gguf.LlamaFileType.GUESSED, | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     ftype = ftype_map[args.outtype] | ||||||
|  | 
 | ||||||
|  |     dir_base_model: Path = args.base | ||||||
|  |     dir_lora: Path = args.lora_path | ||||||
|  |     lora_config = dir_lora / "adapter_config.json" | ||||||
|  |     input_model = dir_lora / "adapter_model.safetensors" | ||||||
|  | 
 | ||||||
|  |     if args.outfile is not None: | ||||||
|  |         fname_out = args.outfile | ||||||
|  |     else: | ||||||
|  |         # output in the same directory as the model by default | ||||||
|  |         fname_out = dir_lora / 'ggml-lora-{ftype}.gguf' | ||||||
|  | 
 | ||||||
|  |     if os.path.exists(input_model): | ||||||
|  |         # lazy import load_file only if lora is in safetensors format. | ||||||
|  |         from safetensors.torch import load_file | ||||||
|  | 
 | ||||||
|  |         lora_model = load_file(input_model, device="cpu") | ||||||
|  |     else: | ||||||
|  |         input_model = os.path.join(dir_lora, "adapter_model.bin") | ||||||
|  |         lora_model = torch.load(input_model, map_location="cpu", weights_only=True) | ||||||
|  | 
 | ||||||
|  |     # load base model | ||||||
|  |     logger.info(f"Loading base model: {dir_base_model.name}") | ||||||
|  |     hparams = Model.load_hparams(dir_base_model) | ||||||
|  |     with torch.inference_mode(): | ||||||
|  |         try: | ||||||
|  |             model_class = Model.from_model_architecture(hparams["architectures"][0]) | ||||||
|  |         except NotImplementedError: | ||||||
|  |             logger.error(f"Model {hparams['architectures'][0]} is not supported") | ||||||
|  |             sys.exit(1) | ||||||
|  | 
 | ||||||
|  |         class LoraModel(model_class): | ||||||
|  |             model_arch = model_class.model_arch | ||||||
|  | 
 | ||||||
|  |             def get_tensors(self) -> Iterator[tuple[str, Tensor]]: | ||||||
|  |                 tensor_map: dict[str, PartialLoraTensor] = {} | ||||||
|  | 
 | ||||||
|  |                 for name, tensor in lora_model.items(): | ||||||
|  |                     if self.lazy: | ||||||
|  |                         tensor = LazyTorchTensor.from_eager(tensor) | ||||||
|  |                     base_name = get_base_tensor_name(name) | ||||||
|  |                     is_lora_a = ".lora_A.weight" in name | ||||||
|  |                     is_lora_b = ".lora_B.weight" in name | ||||||
|  |                     if not is_lora_a and not is_lora_b: | ||||||
|  |                         if ".base_layer.weight" in name: | ||||||
|  |                             continue | ||||||
|  |                         logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") | ||||||
|  |                         sys.exit(1) | ||||||
|  | 
 | ||||||
|  |                     if base_name in tensor_map: | ||||||
|  |                         if is_lora_a: | ||||||
|  |                             tensor_map[base_name].A = tensor | ||||||
|  |                         else: | ||||||
|  |                             tensor_map[base_name].B = tensor | ||||||
|  |                     else: | ||||||
|  |                         if is_lora_a: | ||||||
|  |                             tensor_map[base_name] = PartialLoraTensor(A=tensor) | ||||||
|  |                         else: | ||||||
|  |                             tensor_map[base_name] = PartialLoraTensor(B=tensor) | ||||||
|  | 
 | ||||||
|  |                 for name, tensor in tensor_map.items(): | ||||||
|  |                     assert tensor.A is not None | ||||||
|  |                     assert tensor.B is not None | ||||||
|  |                     yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B))) | ||||||
|  | 
 | ||||||
|  |             def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: | ||||||
|  |                 dest = super().modify_tensors(data_torch, name, bid) | ||||||
|  |                 for dest_name, dest_data in dest: | ||||||
|  |                     assert isinstance(dest_data, LoraTorchTensor) | ||||||
|  |                     lora_a, lora_b = dest_data.get_lora_A_B() | ||||||
|  | 
 | ||||||
|  |                     yield (dest_name + ".lora_a", lora_a) | ||||||
|  |                     yield (dest_name + ".lora_b", lora_b) | ||||||
|  | 
 | ||||||
|  |         model_instance = LoraModel( | ||||||
|  |             dir_base_model, | ||||||
|  |             ftype, | ||||||
|  |             fname_out, | ||||||
|  |             is_big_endian=args.bigendian, | ||||||
|  |             use_temp_file=False, | ||||||
|  |             eager=args.no_lazy, | ||||||
|  |             model_name=None, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         with open(lora_config, "r") as f: | ||||||
|  |             lparams: dict[str, Any] = json.load(f) | ||||||
|  | 
 | ||||||
|  |         alpha = lparams["lora_alpha"] | ||||||
|  | 
 | ||||||
|  |         model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER) | ||||||
|  |         model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") | ||||||
|  |         model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha)) | ||||||
|  |         model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) | ||||||
|  |         logger.info("Exporting model...") | ||||||
|  |         model_instance.write() | ||||||
|  |         logger.info(f"Model successfully exported to {model_instance.fname_out}") | ||||||
|  | @ -1876,7 +1876,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor | ||||||
| 
 | 
 | ||||||
|     bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) |     bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) | ||||||
|         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 |         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 | ||||||
|         && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; |         && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2 | ||||||
|  |         && src1->ne[1] == 1; | ||||||
|     bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type) |     bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type) | ||||||
|         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 |         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 | ||||||
|         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; |         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; | ||||||
|  |  | ||||||
|  | @ -19478,7 +19478,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph | ||||||
| 
 | 
 | ||||||
|     fprintf(fp, "digraph G {\n"); |     fprintf(fp, "digraph G {\n"); | ||||||
|     fprintf(fp, "  newrank = true;\n"); |     fprintf(fp, "  newrank = true;\n"); | ||||||
|     fprintf(fp, "  rankdir = LR;\n"); |     fprintf(fp, "  rankdir = TB;\n"); | ||||||
| 
 | 
 | ||||||
|     for (int i = 0; i < gb->n_nodes; i++) { |     for (int i = 0; i < gb->n_nodes; i++) { | ||||||
|         struct ggml_tensor * node = gb->nodes[i]; |         struct ggml_tensor * node = gb->nodes[i]; | ||||||
|  | @ -19540,7 +19540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); |         fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); | ||||||
|         if (ggml_nelements(node) < 5) { |         if (ggml_nelements(node) < 5 && node->data != NULL) { | ||||||
|             fprintf(fp, " | ("); |             fprintf(fp, " | ("); | ||||||
|             for (int j = 0; j < ggml_nelements(node); j++) { |             for (int j = 0; j < ggml_nelements(node); j++) { | ||||||
|                 if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { |                 if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ GGML_QUANT_VERSION     = 2  # GGML_QNT_VERSION from ggml.h | ||||||
| 
 | 
 | ||||||
| class Keys: | class Keys: | ||||||
|     class General: |     class General: | ||||||
|  |         TYPE                 = "general.type" | ||||||
|         ARCHITECTURE         = "general.architecture" |         ARCHITECTURE         = "general.architecture" | ||||||
|         QUANTIZATION_VERSION = "general.quantization_version" |         QUANTIZATION_VERSION = "general.quantization_version" | ||||||
|         ALIGNMENT            = "general.alignment" |         ALIGNMENT            = "general.alignment" | ||||||
|  | @ -120,11 +121,20 @@ class Keys: | ||||||
|         MIDDLE_ID            = "tokenizer.ggml.middle_token_id" |         MIDDLE_ID            = "tokenizer.ggml.middle_token_id" | ||||||
|         EOT_ID               = "tokenizer.ggml.eot_token_id" |         EOT_ID               = "tokenizer.ggml.eot_token_id" | ||||||
| 
 | 
 | ||||||
|  |     class Adapter: | ||||||
|  |         TYPE       = "adapter.type" | ||||||
|  |         LORA_ALPHA = "adapter.lora.alpha" | ||||||
|  | 
 | ||||||
| # | # | ||||||
| # recommended mapping of model tensor names for storage in gguf | # recommended mapping of model tensor names for storage in gguf | ||||||
| # | # | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class GGUFType: | ||||||
|  |     MODEL   = "model" | ||||||
|  |     ADAPTER = "adapter" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class MODEL_ARCH(IntEnum): | class MODEL_ARCH(IntEnum): | ||||||
|     LLAMA        = auto() |     LLAMA        = auto() | ||||||
|     FALCON       = auto() |     FALCON       = auto() | ||||||
|  |  | ||||||
|  | @ -424,6 +424,9 @@ class GGUFWriter: | ||||||
|                 fout.close() |                 fout.close() | ||||||
|             self.fout = None |             self.fout = None | ||||||
| 
 | 
 | ||||||
|  |     def add_type(self, type_name: str) -> None: | ||||||
|  |         self.add_string(Keys.General.TYPE, type_name) | ||||||
|  | 
 | ||||||
|     def add_architecture(self) -> None: |     def add_architecture(self) -> None: | ||||||
|         self.add_string(Keys.General.ARCHITECTURE, self.arch) |         self.add_string(Keys.General.ARCHITECTURE, self.arch) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np. | ||||||
|         osize *= dim |         osize *= dim | ||||||
|     out = np.empty(shape=osize, dtype=otype) |     out = np.empty(shape=osize, dtype=otype) | ||||||
|     # compute over groups of 16 rows (arbitrary, but seems good for performance) |     # compute over groups of 16 rows (arbitrary, but seems good for performance) | ||||||
|     n_groups = rows.shape[0] // 16 |     n_groups = (rows.shape[0] // 16) or 1 | ||||||
|     np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out) |     np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out) | ||||||
|     return out.reshape(oshape) |     return out.reshape(oshape) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -411,6 +411,9 @@ extern "C" { | ||||||
|         const char * content; |         const char * content; | ||||||
|     } llama_chat_message; |     } llama_chat_message; | ||||||
| 
 | 
 | ||||||
|  |     // lora adapter
 | ||||||
|  |     struct llama_lora_adapter; | ||||||
|  | 
 | ||||||
|     // Helpers for getting default parameters
 |     // Helpers for getting default parameters
 | ||||||
|     LLAMA_API struct llama_model_params llama_model_default_params(void); |     LLAMA_API struct llama_model_params llama_model_default_params(void); | ||||||
|     LLAMA_API struct llama_context_params llama_context_default_params(void); |     LLAMA_API struct llama_context_params llama_context_default_params(void); | ||||||
|  | @ -510,18 +513,28 @@ extern "C" { | ||||||
|             const char * fname_out, |             const char * fname_out, | ||||||
|             const llama_model_quantize_params * params); |             const llama_model_quantize_params * params); | ||||||
| 
 | 
 | ||||||
|     // Apply a LoRA adapter to a loaded model
 |     // Load a LoRA adapter from file
 | ||||||
|     // path_base_model is the path to a higher quality model to use as a base for
 |     // The loaded adapter will be associated to the given model, and will be free when the model is deleted
 | ||||||
|     // the layers modified by the adapter. Can be NULL to use the current loaded model.
 |     LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init( | ||||||
|     // The model needs to be reloaded before applying a new adapter, otherwise the adapter
 |             struct llama_model * model, | ||||||
|     // will be applied on top of the previous one
 |             const char * path_lora); | ||||||
|     // Returns 0 on success
 | 
 | ||||||
|     LLAMA_API int32_t llama_model_apply_lora_from_file( |     // Add a loaded LoRA adapter to given context
 | ||||||
|             const struct llama_model * model, |     // This will not modify model's weight
 | ||||||
|                           const char * path_lora, |     LLAMA_API int32_t llama_lora_adapter_set( | ||||||
|                                float   scale, |             struct llama_context * ctx, | ||||||
|                           const char * path_base_model, |             struct llama_lora_adapter * adapter, | ||||||
|                              int32_t   n_threads); |             float scale); | ||||||
|  | 
 | ||||||
|  |     // Remove a LoRA adapter from given context
 | ||||||
|  |     // Return -1 if the adapter is not present in the context
 | ||||||
|  |     LLAMA_API int32_t llama_lora_adapter_remove( | ||||||
|  |             struct llama_context * ctx, | ||||||
|  |             struct llama_lora_adapter * adapter); | ||||||
|  | 
 | ||||||
|  |     // Manually free a LoRA adapter
 | ||||||
|  |     // Note: loaded adapters will be free when the associated model is deleted
 | ||||||
|  |     LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); | ||||||
| 
 | 
 | ||||||
|     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
 |     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
 | ||||||
|     // the currently loaded vector.
 |     // the currently loaded vector.
 | ||||||
|  |  | ||||||
|  | @ -9,3 +9,4 @@ | ||||||
| -r ./requirements/requirements-convert_hf_to_gguf.txt | -r ./requirements/requirements-convert_hf_to_gguf.txt | ||||||
| -r ./requirements/requirements-convert_hf_to_gguf_update.txt | -r ./requirements/requirements-convert_hf_to_gguf_update.txt | ||||||
| -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt | -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt | ||||||
|  | -r ./requirements/requirements-convert_lora_to_gguf.txt | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								requirements/requirements-convert_lora_to_gguf.txt
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								requirements/requirements-convert_lora_to_gguf.txt
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,2 @@ | ||||||
|  | -r ./requirements-convert_hf_to_gguf.txt | ||||||
|  | --extra-index-url https://download.pytorch.org/whl/cpu | ||||||
							
								
								
									
										1010
									
								
								src/llama.cpp
									
										
									
									
									
								
							
							
						
						
									
										1010
									
								
								src/llama.cpp
									
										
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue