Minor types cleanups.

This commit is contained in:
KerfuffleV2 2023-08-27 11:16:46 -06:00
parent 084dd216cd
commit 795c0c6e9d
2 changed files with 4 additions and 4 deletions

View file

@ -113,7 +113,7 @@ gguf_writer.add_file_type(ftype)
print("gguf: get tokenizer metadata") print("gguf: get tokenizer metadata")
tokens: List[str] = [] tokens: List[bytearray] = []
scores: List[float] = [] scores: List[float] = []
toktypes: List[int] = [] toktypes: List[int] = []
merges: List[str] = [] merges: List[str] = []
@ -199,7 +199,7 @@ head_dim = hparams["hidden_size"] // n_head
print("gguf: get tensor metadata") print("gguf: get tensor metadata")
if num_parts == 0: if num_parts == 0:
part_names = ("pytorch_model.bin",) part_names = iter(("pytorch_model.bin",))
else: else:
part_names = ( part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)

View file

@ -9,7 +9,7 @@ import json
import numpy as np import numpy as np
import torch import torch
from typing import Any, List, Optional from typing import Any, List, Optional, TypeAlias
from pathlib import Path from pathlib import Path
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
@ -254,7 +254,7 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
print("gguf: get tensor metadata") print("gguf: get tensor metadata")
if num_parts == 0: if num_parts == 0:
part_names = ("pytorch_model.bin",) part_names = iter(("pytorch_model.bin",))
else: else:
part_names = ( part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)