add arg parser to qwen2vl_surgery

This commit is contained in:
HimariO 2024-10-21 02:28:19 +08:00
parent 023f0076e0
commit 3d19dd44b6

View file

@ -73,7 +73,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
return tensor_map return tensor_map
def main(data_type='fp32'): def main(args, data_type='fp32'):
if data_type == 'fp32': if data_type == 'fp32':
dtype = torch.float32 dtype = torch.float32
np_dtype = np.float32 np_dtype = np.float32
@ -85,7 +85,8 @@ def main(data_type='fp32'):
else: else:
raise ValueError() raise ValueError()
model_name = "Qwen/Qwen2-VL-2B-Instruct" model_name = args.model_name
print("model_name: ", model_name)
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype, device_map="cpu" model_name, torch_dtype=dtype, device_map="cpu"
) )
@ -140,4 +141,8 @@ def main(data_type='fp32'):
fout.close() fout.close()
main() if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
args = parser.parse_args()
main(args)