diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index d2557500a..e0899f248 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -37,17 +37,16 @@ fname_hparams = sys.argv[1] + "/params.json" fname_tokenizer = sys.argv[1] + "/../tokenizer.model" def get_n_parts(dim): - if dim == 4096: - return 1 - elif dim == 5120: - return 2 - elif dim == 6656: - return 4 - elif dim == 8192: - return 8 - else: - print("Invalid dim: " + str(dim)) + mappings = { + 4096: 1, + 5120: 2, + 6656: 4, + 8192: 8 + } + if dim not in mappings: + print(f"Invalid dim: {dim}") sys.exit(1) + return mappings[dim] # possible data types # ftype == 0 -> float32