diff --git a/merge.py b/merge-HF-and-lora-to-HF.py
similarity index 50%
rename from merge.py
rename to merge-HF-and-lora-to-HF.py
index f881802a4..bd3db4609 100644
--- a/merge.py
+++ b/merge-HF-and-lora-to-HF.py
@@ -1,31 +1,31 @@
-import os, time
-import tempfile
 import json
 import torch
 import argparse
-import transformers
 from transformers import LlamaTokenizer, LlamaForCausalLM
 from peft import PeftModel
 
-# args
-parser = argparse.ArgumentParser()
+# Argument parser with a short description of the expected inputs.
+parser = argparse.ArgumentParser(
+    prog="Merge HF model with LoRA",
+    description="Provide an HF-format model path containing pytorch_*.bin, and a LoRA path containing adapter_config.json and adapter_model.bin.",
+)
+
 # The original base model checkpoint dir
-parser.add_argument("--model_path", type=str, default='llama-7b-hf')
+parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf", help="Directory containing the original HF model")
 # The finetuned lora model checkpoint dir
-parser.add_argument("--lora_path",type=str, default='lora')
+parser.add_argument("--lora_path", type=str, default="decapoda-research/lora", help="Directory containing the LoRA adapter")
 # The output dir
-parser.add_argument("--out_path", type=str, default='lora-merged')
+parser.add_argument("--out_path", type=str, default="decapoda-research/lora-merged", help="Directory in which to store the merged HF model")
 args = parser.parse_args()
-
-
 
 print(f">>> load model from {args.model_path} and lora from {args.lora_path}....")
 
+# Load the tokenizer and save a copy to the output dir.
 tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
+tokenizer.save_pretrained(args.out_path)
 
-#transformer loaded. load model.
-
+# Load the base model.
 model = LlamaForCausalLM.from_pretrained(
     args.model_path,
     load_in_8bit=False,
@@ -34,7 +34,7 @@ model = LlamaForCausalLM.from_pretrained(
 )
 
-#peft loaded. load lora.
+# Load the LoRA adapter on top of the base model with peft.
 model = PeftModel.from_pretrained(
     model,
     args.lora_path,
@@ -44,7 +44,6 @@ model = PeftModel.from_pretrained(
 )
 
 print(f">>> merging lora...")
-#Using new Peft function merge Lora
+# Merge the LoRA weights into the base model and drop the peft wrappers.
 model = model.merge_and_unload()
 model.save_pretrained(args.out_path)
-
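
A quick way to sanity-check the result of this change (not part of the diff itself): after running the script, the checkpoint written to --out_path should load as a plain HF model with no peft dependency. A minimal sketch, assuming the script was run with its default --out_path of decapoda-research/lora-merged (adjust the path to your own run):

    # Sketch only: the path below is the script's default --out_path, not a published model.
    from transformers import LlamaTokenizer, LlamaForCausalLM

    tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/lora-merged")
    model = LlamaForCausalLM.from_pretrained("decapoda-research/lora-merged")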