Update and rename merge.py to merge-HF-and-lora-to-HF.py

FNsi 2023-05-21 11:09:08 +08:00 committed by GitHub
parent 3c6bdad892
commit e970d41095


@@ -1,31 +1,31 @@
 import os, time
 import tempfile
 import json
 import torch
 import argparse
 import transformers
 from transformers import LlamaTokenizer, LlamaForCausalLM
 from peft import PeftModel
-# args
-parser = argparse.ArgumentParser()
+# args with description.
+parser = argparse.ArgumentParser(
+    prog="Merge HF file with Lora\n",
+    description="Please locate the HF-format model path with pytorch_*.bin inside, and the Lora path with adapter_config.json and adapter_model.bin.",
+)
 # The original base model checkpoint dir
-parser.add_argument("--model_path", type=str, default='llama-7b-hf')
+parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf", help="Directory containing the original HF model")
 # The finetuned lora model checkpoint dir
-parser.add_argument("--lora_path",type=str, default='lora')
+parser.add_argument("--lora_path", type=str, default="decapoda-research/lora", help="Directory containing the Lora adapter")
 # The output dir
-parser.add_argument("--out_path", type=str, default='lora-merged')
+parser.add_argument("--out_path", type=str, default="decapoda-research/lora-merged", help="Directory to store the merged HF model")
 args = parser.parse_args()
 print(f">>> load model from {args.model_path} and lora from {args.lora_path}....")
+# transformer loaded. load and save tokenizer.
 tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
-#transformer loaded. load model.
+tokenizer.save_pretrained(args.out_path)
+# load model.
 model = LlamaForCausalLM.from_pretrained(
     args.model_path,
     load_in_8bit=False,
@@ -44,7 +44,6 @@ model = PeftModel.from_pretrained(
 print(f">>> merging lora...")
-#Using new Peft function merge Lora
+# Using the new Peft function to merge the Lora weights.
 model = model.merge_and_unload()
 model.save_pretrained(args.out_path)
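For reference, the updated script reduces to the flow below. This is a minimal sketch reconstructed from the hunks above; the parts the diff truncates (the remaining from_pretrained keyword arguments and the exact PeftModel.from_pretrained call) are assumptions, not the commit's verbatim code.

import argparse
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
parser.add_argument("--lora_path", type=str, default="decapoda-research/lora")
parser.add_argument("--out_path", type=str, default="lora-merged")
args = parser.parse_args()

# Save the tokenizer into the output dir right away, as the commit now does.
tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
tokenizer.save_pretrained(args.out_path)

# Load the base model in full precision (load_in_8bit=False per the diff).
# torch_dtype and device_map are assumed; the diff cuts off before them.
model = LlamaForCausalLM.from_pretrained(
    args.model_path,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)

# Attach the Lora adapter, then fold its weights into the base model.
# The exact arguments of this call are not shown in the diff (assumed here).
model = PeftModel.from_pretrained(model, args.lora_path, torch_dtype=torch.float16)
model = model.merge_and_unload()  # the Peft merge function this commit adopts
model.save_pretrained(args.out_path)

A typical invocation, assuming the new filename: python merge-HF-and-lora-to-HF.py --model_path ./llama-7b-hf --lora_path ./lora --out_path ./lora-merged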