convert-lora : make --base
optional (#10110)
* convert-lora : make `--base` optional * lint * handle case where base_model_name_or_path is invalid * do not include metadata from base model * clarify unspecified --base * add small comment [no ci] * trigger ci
This commit is contained in:
parent
a6744e43e8
commit
7554aa4655
2 changed files with 51 additions and 23 deletions
|
@ -12,6 +12,7 @@ import json
|
|||
from math import prod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
|
||||
from transformers import AutoConfig
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -256,8 +257,8 @@ def parse_args() -> argparse.Namespace:
|
|||
help="only print out what will be done, without writing any new files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base", type=Path, required=True,
|
||||
help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
|
||||
"--base", type=Path,
|
||||
help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"lora_path", type=Path,
|
||||
|
@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
|
||||
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
||||
config = AutoConfig.from_pretrained(hf_model_id)
|
||||
return config.to_dict()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
@ -281,7 +288,7 @@ if __name__ == '__main__':
|
|||
|
||||
ftype = ftype_map[args.outtype]
|
||||
|
||||
dir_base_model: Path = args.base
|
||||
dir_base_model: Path | None = args.base
|
||||
dir_lora: Path = args.lora_path
|
||||
lora_config = dir_lora / "adapter_config.json"
|
||||
input_model = dir_lora / "adapter_model.safetensors"
|
||||
|
@ -301,9 +308,29 @@ if __name__ == '__main__':
|
|||
input_model = os.path.join(dir_lora, "adapter_model.bin")
|
||||
lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
|
||||
|
||||
# load LoRA config
|
||||
with open(lora_config, "r") as f:
|
||||
lparams: dict[str, Any] = json.load(f)
|
||||
|
||||
# load base model
|
||||
logger.info(f"Loading base model: {dir_base_model.name}")
|
||||
hparams = Model.load_hparams(dir_base_model)
|
||||
if dir_base_model is None:
|
||||
if "base_model_name_or_path" in lparams:
|
||||
model_id = lparams["base_model_name_or_path"]
|
||||
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
||||
try:
|
||||
hparams = load_hparams_from_hf(model_id)
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to load base model config: {e}")
|
||||
logger.error("Please try downloading the base model and add its path to --base")
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error("'base_model_name_or_path' is not found in adapter_config.json")
|
||||
logger.error("Base model config is required. Please download the base model and add its path to --base")
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.info(f"Loading base model: {dir_base_model.name}")
|
||||
hparams = Model.load_hparams(dir_base_model)
|
||||
|
||||
with torch.inference_mode():
|
||||
try:
|
||||
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
||||
|
@ -323,13 +350,15 @@ if __name__ == '__main__':
|
|||
self.dir_model_card = dir_lora_model
|
||||
self.lora_alpha = float(lora_alpha)
|
||||
|
||||
def set_vocab(self):
|
||||
pass
|
||||
|
||||
def set_type(self):
|
||||
self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
|
||||
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
|
||||
super().set_gguf_parameters()
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters
|
||||
|
@ -350,7 +379,7 @@ if __name__ == '__main__':
|
|||
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
|
||||
if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
|
||||
logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
|
||||
logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
|
||||
logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
|
||||
sys.exit(1)
|
||||
|
||||
if base_name in tensor_map:
|
||||
|
@ -384,9 +413,6 @@ if __name__ == '__main__':
|
|||
yield (dest_name + ".lora_a", lora_a)
|
||||
yield (dest_name + ".lora_b", lora_b)
|
||||
|
||||
with open(lora_config, "r") as f:
|
||||
lparams: dict[str, Any] = json.load(f)
|
||||
|
||||
alpha: float = lparams["lora_alpha"]
|
||||
|
||||
model_instance = LoraModel(
|
||||
|
@ -399,6 +425,7 @@ if __name__ == '__main__':
|
|||
dry_run=args.dry_run,
|
||||
dir_lora_model=dir_lora,
|
||||
lora_alpha=alpha,
|
||||
hparams=hparams,
|
||||
)
|
||||
|
||||
logger.info("Exporting model...")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue