Refactor get_n_parts function to simplify code and improve readability
This commit is contained in:
parent
2f700a2738
commit
94f368fd53
1 changed files with 9 additions and 10 deletions
|
@ -37,17 +37,16 @@ fname_hparams = sys.argv[1] + "/params.json"
|
|||
fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
|
||||
|
||||
def get_n_parts(dim):
|
||||
if dim == 4096:
|
||||
return 1
|
||||
elif dim == 5120:
|
||||
return 2
|
||||
elif dim == 6656:
|
||||
return 4
|
||||
elif dim == 8192:
|
||||
return 8
|
||||
else:
|
||||
print("Invalid dim: " + str(dim))
|
||||
mappings = {
|
||||
4096: 1,
|
||||
5120: 2,
|
||||
6656: 4,
|
||||
8192: 8
|
||||
}
|
||||
if dim not in mappings:
|
||||
print(f"Invalid dim: {dim}")
|
||||
sys.exit(1)
|
||||
return mappings[dim]
|
||||
|
||||
# possible data types
|
||||
# ftype == 0 -> float32
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue