Update and rename merge.py to merge-HF-and-lora-to-HF.py

FNsi 2023-05-21 11:09:08 +08:00 committed by GitHub
parent 3c6bdad892
commit e970d41095


@@ -1,31 +1,31 @@
import os, time
import tempfile
import json
import torch
import argparse
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel
-# args
-parser = argparse.ArgumentParser()
+# args with description.
+parser = argparse.ArgumentParser(
+    prog="Merge HF file with Lora",
+    description="Point --model_path at an HF-format model directory (pytorch_*.bin inside) and --lora_path at a LoRA directory (adapter_config.json and adapter_model.bin).",
+)
# The original base model checkpoint dir
parser.add_argument("--model_path", type=str, default='llama-7b-hf')
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf", help="Directory contain original HF model")
# The finetuned lora model checkpoint dir
parser.add_argument("--lora_path",type=str, default='lora')
parser.add_argument("--lora_path", type=str, default="decapoda-research/lora", help="Directory contain Lora ")
# The output dir
parser.add_argument("--out_path", type=str, default='lora-merged')
parser.add_argument("--out_path", type=str, default="decapoda-research/lora-merged", help="Directory store merged HF model")
args = parser.parse_args()
print(f">>> load model from {args.model_path} and lora from {args.lora_path}....")
+# transformers loaded. load and save the tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
tokenizer.save_pretrained(args.out_path)
-#transformer loaded. load model.
+# load model.
model = LlamaForCausalLM.from_pretrained(
args.model_path,
load_in_8bit=False,
@@ -34,7 +34,7 @@ model = LlamaForCausalLM.from_pretrained(
)
-#peft loaded. load lora.
+# peft loaded. load lora.
model = PeftModel.from_pretrained(
model,
args.lora_path,
@@ -44,7 +44,6 @@ model = PeftModel.from_pretrained(
print(f">>> merging lora...")
-#Using new Peft function merge Lora
+# Using Peft function to merge Lora.
model = model.merge_and_unload()
model.save_pretrained(args.out_path)
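
Not part of this commit, but for context: after the renamed script runs, e.g.

    python merge-HF-and-lora-to-HF.py --model_path llama-7b-hf --lora_path lora --out_path lora-merged

the output directory is a plain HF checkpoint and needs no peft at load time. A minimal smoke test sketch, assuming placeholder paths and an installed transformers; none of these names come from the commit itself:

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

out_path = "lora-merged"  # placeholder; must match the --out_path given to the merge script

# merge_and_unload() folded the LoRA deltas into the base weights,
# so the merged model loads like any ordinary HF LLaMA checkpoint.
tokenizer = LlamaTokenizer.from_pretrained(out_path)
model = LlamaForCausalLM.from_pretrained(out_path, torch_dtype=torch.float16)

inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))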