Refactor get_n_parts function to simplify code and improve readability

This commit is contained in:
qunash 2023-03-14 01:50:50 +03:00
parent 2f700a2738
commit 94f368fd53

View file

@ -37,17 +37,16 @@ fname_hparams = sys.argv[1] + "/params.json"
fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
def get_n_parts(dim):
if dim == 4096:
return 1
elif dim == 5120:
return 2
elif dim == 6656:
return 4
elif dim == 8192:
return 8
else:
print("Invalid dim: " + str(dim))
mappings = {
4096: 1,
5120: 2,
6656: 4,
8192: 8
}
if dim not in mappings:
print(f"Invalid dim: {dim}")
sys.exit(1)
return mappings[dim]
# possible data types
# ftype == 0 -> float32