Refactor get_n_parts function to simplify code and improve readability
This commit is contained in:
parent
2f700a2738
commit
94f368fd53
1 changed files with 9 additions and 10 deletions
|
@ -37,17 +37,16 @@ fname_hparams = sys.argv[1] + "/params.json"
|
||||||
fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
|
fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
|
||||||
|
|
||||||
def get_n_parts(dim):
|
def get_n_parts(dim):
|
||||||
if dim == 4096:
|
mappings = {
|
||||||
return 1
|
4096: 1,
|
||||||
elif dim == 5120:
|
5120: 2,
|
||||||
return 2
|
6656: 4,
|
||||||
elif dim == 6656:
|
8192: 8
|
||||||
return 4
|
}
|
||||||
elif dim == 8192:
|
if dim not in mappings:
|
||||||
return 8
|
print(f"Invalid dim: {dim}")
|
||||||
else:
|
|
||||||
print("Invalid dim: " + str(dim))
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
return mappings[dim]
|
||||||
|
|
||||||
# possible data types
|
# possible data types
|
||||||
# ftype == 0 -> float32
|
# ftype == 0 -> float32
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue