add arg parser to qwen2vl_surgery
This commit is contained in:
parent
023f0076e0
commit
3d19dd44b6
1 changed files with 8 additions and 3 deletions
|
@ -73,7 +73,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
||||||
return tensor_map
|
return tensor_map
|
||||||
|
|
||||||
|
|
||||||
def main(data_type='fp32'):
|
def main(args, data_type='fp32'):
|
||||||
if data_type == 'fp32':
|
if data_type == 'fp32':
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
np_dtype = np.float32
|
np_dtype = np.float32
|
||||||
|
@ -85,7 +85,8 @@ def main(data_type='fp32'):
|
||||||
else:
|
else:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
||||||
model_name = "Qwen/Qwen2-VL-2B-Instruct"
|
model_name = args.model_name
|
||||||
|
print("model_name: ", model_name)
|
||||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
model_name, torch_dtype=dtype, device_map="cpu"
|
model_name, torch_dtype=dtype, device_map="cpu"
|
||||||
)
|
)
|
||||||
|
@ -140,4 +141,8 @@ def main(data_type='fp32'):
|
||||||
fout.close()
|
fout.close()
|
||||||
|
|
||||||
|
|
||||||
main()
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
Loading…
Add table
Add a link
Reference in a new issue