From 36bff51a7ae4ea5bc5eaf6f49460699692ad4b54 Mon Sep 17 00:00:00 2001
From: Achazwl <323163497@qq.com>
Date: Fri, 3 May 2024 10:06:36 +0800
Subject: [PATCH] fix tokenizer.json tokenizer_config.json cpu()

---
 examples/minicpmv/minicpm-surgery.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/examples/minicpmv/minicpm-surgery.py b/examples/minicpmv/minicpm-surgery.py
index 97a02c6ca..85b498c97 100644
--- a/examples/minicpmv/minicpm-surgery.py
+++ b/examples/minicpmv/minicpm-surgery.py
@@ -1,6 +1,6 @@
 import argparse
 import glob
-import os
+import os, json
 import torch
 from transformers import AutoModel, AutoTokenizer
 
@@ -16,12 +16,12 @@ checkpoint = model.state_dict()
 mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")]
 
 # store these tensors in a new dictionary and torch.save them
-projector = {name: checkpoint[name].float() for name in mm_tensors}
+projector = {name: checkpoint[name].float().cpu() for name in mm_tensors}
 torch.save(projector, f"{args.model}/llava.projector")
 
 clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")]
 if len(clip_tensors) > 0:
-    clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors}
+    clip = {name.replace("vpm.", ""): checkpoint[name].float().cpu() for name in clip_tensors}
     torch.save(clip, f"{args.model}/llava.clip")
 
     # added tokens should be removed to be able to convert Mistral models
@@ -42,6 +42,15 @@ model.llm.save_pretrained(f"{args.model}/MiniCPM")
 tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
 tok.save_pretrained(f"{args.model}/MiniCPM")
 os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM/modeling_minicpm.py")
+os.system(f"cp {args.model}/tokenizer.json {args.model}/MiniCPM/tokenizer.json")
+with open(f"{args.model}/MiniCPM/tokenizer_config.json", "r") as f:
+    d = json.load(f)
+    d.pop("auto_map")
+    d["tokenizer_class"] = "LlamaTokenizer"
+    d.pop("add_prefix_space")
+with open(f"{args.model}/MiniCPM/tokenizer_config.json", "w") as f:
+    json.dump(d, f, indent=2)
+
 
 print("Done!")
 print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
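
Reviewer note (not part of the patch): the added `.cpu()` calls matter because
the checkpoint tensors may reside on a GPU after model loading; without moving
them to host memory, torch.save would serialize CUDA tensors, which fail to
load on a CPU-only machine unless torch.load is given an explicit map_location.
The tokenizer_config.json rewrite strips the remote-code "auto_map" entry and
pins "tokenizer_class" to "LlamaTokenizer" so the saved MiniCPM folder loads
with the stock tokenizer during GGUF conversion.

One caveat: d.pop("auto_map") and d.pop("add_prefix_space") raise KeyError if
a key is already absent. Below is a minimal standalone sketch of the same
rewrite, hedged with pop defaults so missing keys are skipped; model_dir is a
hypothetical path standing in for args.model in the script above:

    import json

    def patch_tokenizer_config(model_dir: str) -> None:
        path = f"{model_dir}/MiniCPM/tokenizer_config.json"
        # Load the config written by tok.save_pretrained(...)
        with open(path, "r") as f:
            d = json.load(f)
        # Remove remote-code hooks; pop with a default tolerates absent keys.
        d.pop("auto_map", None)
        d.pop("add_prefix_space", None)
        # Point the config at the stock tokenizer class for conversion tools.
        d["tokenizer_class"] = "LlamaTokenizer"
        with open(path, "w") as f:
            json.dump(d, f, indent=2)

    # e.g. patch_tokenizer_config(args.model) after the surgery script's saves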