Update and rename merge.py to merge-HF-and-lora-to-HF.py
This commit is contained in:
parent 3c6bdad892
commit e970d41095
1 changed file with 14 additions and 15 deletions
merge.py → merge-HF-and-lora-to-HF.py

@@ -1,31 +1,31 @@
 import os, time
 import tempfile
 import json
 import torch
 import argparse
 import transformers
 from transformers import LlamaTokenizer, LlamaForCausalLM
 from peft import PeftModel

-# args
-parser = argparse.ArgumentParser()
+# args with description.
+parser = argparse.ArgumentParser(
+    prog="Merge HF file with Lora\n",
+    description="Please locate HF format model path with pytorch_*.bin inside, lora path with adapter_config.json and adapter_model.bin.",
+)

 # The original base model checkpoint dir
-parser.add_argument("--model_path", type=str, default='llama-7b-hf')
+parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf", help="Directory contain original HF model")
 # The finetuned lora model checkpoint dir
-parser.add_argument("--lora_path",type=str, default='lora')
+parser.add_argument("--lora_path", type=str, default="decapoda-research/lora", help="Directory contain Lora ")
 # The output dir
-parser.add_argument("--out_path", type=str, default='lora-merged')
+parser.add_argument("--out_path", type=str, default="decapoda-research/lora-merged", help="Directory store merged HF model")

 args = parser.parse_args()


 print(f">>> load model from {args.model_path} and lora from {args.lora_path}....")

 # transformer loaded. load and save Tokenizer.
 tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
 tokenizer.save_pretrained(args.out_path)

-#transformer loaded. load model.
+# load model.
 model = LlamaForCausalLM.from_pretrained(
     args.model_path,
     load_in_8bit=False,
@@ -34,7 +34,7 @@ model = LlamaForCausalLM.from_pretrained(
 )

-#peft loaded. load lora.
+# peft loaded. load lora.
 model = PeftModel.from_pretrained(
     model,
     args.lora_path,
@@ -44,7 +44,6 @@ model = PeftModel.from_pretrained(

 print(f">>> merging lora...")

-#Using new Peft function merge Lora
+# Using Peft function to merge Lora.
 model = model.merge_and_unload()
 model.save_pretrained(args.out_path)
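For reference (not part of this commit): because merge_and_unload() folds the LoRA delta weights into the base weights, the directory written to --out_path is a plain HF checkpoint and loads without peft. A minimal sanity-check sketch, assuming the script was run with --out_path lora-merged (a placeholder path, not the committed default):

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Load the merged checkpoint directly; no PeftModel wrapper is needed,
# since merge_and_unload() already baked the LoRA weights into the base model.
tokenizer = LlamaTokenizer.from_pretrained("lora-merged")
model = LlamaForCausalLM.from_pretrained("lora-merged", torch_dtype=torch.float16)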