merge-hf-and-lora-to-hf.py

This commit is contained in:
FNsi 2023-05-22 19:29:24 +08:00 committed by GitHub
parent 1fd5d10b07
commit 29995194e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -32,11 +32,11 @@ parser.add_argument(
args = parser.parse_args() args = parser.parse_args()
print(f">>> load model from {args.model_path} and lora from {args.lora_path}....") print(f">>> load model from {args.model} and lora from {args.lora}....")
# transformer loaded. load and save Tokenizer. # transformer loaded. load and save Tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(args.model_path) tokenizer = LlamaTokenizer.from_pretrained(args.model)
tokenizer.save_pretrained(args.out_path) tokenizer.save_pretrained(args.out)
# load model. # load model.
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
@ -49,7 +49,7 @@ model = LlamaForCausalLM.from_pretrained(
# peft loaded. load lora. # peft loaded. load lora.
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
args.lora_path, args.lora,
torch_dtype=torch.float16, torch_dtype=torch.float16,
device_map={"": "cpu"}, device_map={"": "cpu"},
) )
@ -58,4 +58,4 @@ print(f">>> merging lora...")
# Using Peft function to merge Lora. # Using Peft function to merge Lora.
model = model.merge_and_unload() model = model.merge_and_unload()
model.save_pretrained(args.out_path) model.save_pretrained(args.out)