From 94f368fd536d303b75cfdc56c57e349d316aca9c Mon Sep 17 00:00:00 2001 From: qunash Date: Tue, 14 Mar 2023 01:50:50 +0300 Subject: [PATCH] Refactor get_n_parts function to simplify code and improve readability --- convert-pth-to-ggml.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index d2557500a..e0899f248 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -37,17 +37,16 @@ fname_hparams = sys.argv[1] + "/params.json" fname_tokenizer = sys.argv[1] + "/../tokenizer.model" def get_n_parts(dim): - if dim == 4096: - return 1 - elif dim == 5120: - return 2 - elif dim == 6656: - return 4 - elif dim == 8192: - return 8 - else: - print("Invalid dim: " + str(dim)) + mappings = { + 4096: 1, + 5120: 2, + 6656: 4, + 8192: 8 + } + if dim not in mappings: + print(f"Invalid dim: {dim}") sys.exit(1) + return mappings[dim] # possible data types # ftype == 0 -> float32