From 3d19dd44b60563cf9b8b68cdf76ae553af32cd3d Mon Sep 17 00:00:00 2001 From: HimariO Date: Mon, 21 Oct 2024 02:28:19 +0800 Subject: [PATCH] add arg parser to qwen2vl_surgery --- examples/llava/qwen2_vl_surgery.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index f873e8cab..0c8bb3ed0 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -73,7 +73,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: return tensor_map -def main(data_type='fp32'): +def main(args, data_type='fp32'): if data_type == 'fp32': dtype = torch.float32 np_dtype = np.float32 @@ -85,7 +85,8 @@ def main(data_type='fp32'): else: raise ValueError() - model_name = "Qwen/Qwen2-VL-2B-Instruct" + model_name = args.model_name + print("model_name: ", model_name) qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( model_name, torch_dtype=dtype, device_map="cpu" ) @@ -140,4 +141,8 @@ def main(data_type='fp32'): fout.close() -main() \ No newline at end of file +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") + args = parser.parse_args() + main(args) \ No newline at end of file