diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index f873e8cab..0c8bb3ed0 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -73,7 +73,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: return tensor_map -def main(data_type='fp32'): +def main(args, data_type='fp32'): if data_type == 'fp32': dtype = torch.float32 np_dtype = np.float32 @@ -85,7 +85,8 @@ def main(data_type='fp32'): else: raise ValueError() - model_name = "Qwen/Qwen2-VL-2B-Instruct" + model_name = args.model_name + print("model_name: ", model_name) qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( model_name, torch_dtype=dtype, device_map="cpu" ) @@ -140,4 +141,8 @@ def main(data_type='fp32'): fout.close() -main() \ No newline at end of file +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") + args = parser.parse_args() + main(args) \ No newline at end of file