Refactor get_n_parts function to simplify code and improve readability

This commit is contained in:
qunash 2023-03-14 01:50:50 +03:00
parent 2f700a2738
commit 94f368fd53

View file

@ -37,17 +37,16 @@ fname_hparams = sys.argv[1] + "/params.json"
fname_tokenizer = sys.argv[1] + "/../tokenizer.model" fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
def get_n_parts(dim): def get_n_parts(dim):
if dim == 4096: mappings = {
return 1 4096: 1,
elif dim == 5120: 5120: 2,
return 2 6656: 4,
elif dim == 6656: 8192: 8
return 4 }
elif dim == 8192: if dim not in mappings:
return 8 print(f"Invalid dim: {dim}")
else:
print("Invalid dim: " + str(dim))
sys.exit(1) sys.exit(1)
return mappings[dim]
# possible data types # possible data types
# ftype == 0 -> float32 # ftype == 0 -> float32