add arg parser to qwen2vl_surgery
This commit is contained in:
parent
023f0076e0
commit
3d19dd44b6
1 changed files with 8 additions and 3 deletions
|
@ -73,7 +73,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
|||
return tensor_map
|
||||
|
||||
|
||||
def main(data_type='fp32'):
|
||||
def main(args, data_type='fp32'):
|
||||
if data_type == 'fp32':
|
||||
dtype = torch.float32
|
||||
np_dtype = np.float32
|
||||
|
@ -85,7 +85,8 @@ def main(data_type='fp32'):
|
|||
else:
|
||||
raise ValueError()
|
||||
|
||||
model_name = "Qwen/Qwen2-VL-2B-Instruct"
|
||||
model_name = args.model_name
|
||||
print("model_name: ", model_name)
|
||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
|
@ -140,4 +141,8 @@ def main(data_type='fp32'):
|
|||
fout.close()
|
||||
|
||||
|
||||
main()
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
Loading…
Add table
Add a link
Reference in a new issue