refactor: Revise check_vocab_size for Enhanced Clarity and Correctness

- Resolved an unreachable branch by reorganizing the conditional structure: the old code nested the `params.n_vocab == vocab.vocab_size` check inside `if params.n_vocab != vocab.vocab_size:`, so that branch could never execute (see the sketch after this message).
- Moved the special case check for `params.n_vocab == -1` to the top of the function and promoted it from a trailing `warnings.warn` to an immediate `ValueError`.
- Flattened the conditional logic for improved clarity and predictability of the function's behavior.

These changes improve the readability and correctness of the `check_vocab_size` function without altering its intended behavior.
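For illustration only (not part of the commit), a minimal sketch of why the old nested check could never fire; the bare variables stand in for `params.n_vocab` and `vocab.vocab_size`:

```python
# Hypothetical values standing in for params.n_vocab and vocab.vocab_size.
n_vocab = 32000
vocab_size = 32000

if n_vocab != vocab_size:        # entered only when the two values differ,
    if n_vocab == vocab_size:    # so this nested equality check is never True
        print("unreachable")     # the old early return for matching sizes lived here
```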
Author: teleprint-me
Date:   2024-01-09 13:30:35 -05:00
Commit: 787860ada2
Parent: dd1c1004f8

@@ -988,34 +988,33 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
 def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
-    if params.n_vocab != vocab.vocab_size:
-        if params.n_vocab == vocab.vocab_size:
-            print(
-                "Ignoring added_tokens.json since model matches vocab size without it."
-            )
-            return
-        if pad_vocab and params.n_vocab > vocab.vocab_size:
-            pad_count = params.n_vocab - vocab.vocab_size
-            print(
-                f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
-            )
-            for i in range(1, (params.n_vocab - vocab.vocab_size) + 1):
-                vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
-            vocab.vocab_size = params.n_vocab
-            return
-        msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
-        msg += f" has {vocab.vocab_size})."
-        if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
-            msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
-        if vocab.vocab_size < params.n_vocab:
-            msg += " Add the --pad-vocab option and try again."
-        # Check if params.n_vocab is -1 and issue a warning
-        if params.n_vocab == -1:
-            warnings.warn(
-                "WARNING: The model's vocab size is set to -1 in params.json. Please update it manually."
-            )
-        raise Exception(msg)
+    # Handle special case where the model's vocab size is not set
+    if params.n_vocab == -1:
+        raise ValueError(
+            "The model's vocab size is set to -1 in params.json. Please update it manually."
+        )
+
+    # Check for a vocab size mismatch
+    if params.n_vocab == vocab.vocab_size:
+        print("Ignoring added_tokens.json since model matches vocab size without it.")
+        return
+
+    if pad_vocab and params.n_vocab > vocab.vocab_size:
+        pad_count = params.n_vocab - vocab.vocab_size
+        print(
+            f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
+        )
+        for i in range(1, pad_count + 1):
+            vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
+        vocab.vocab_size = params.n_vocab
+        return
+
+    msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})."
+    if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
+        msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
+    if vocab.vocab_size < params.n_vocab:
+        msg += " Add the --pad-vocab option and try again."
+    raise Exception(msg)
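
For reference, a rough usage sketch of the reorganized checks. It assumes a llama.cpp checkout with `convert.py` importable and its Python dependencies installed; the `SimpleNamespace` objects are stand-ins exposing only the attributes this hunk touches, not the converter's real `Params`/`Vocab` classes.

```python
from types import SimpleNamespace

from convert import check_vocab_size  # assumes llama.cpp's convert.py is on sys.path

# Stand-ins modelling only the attributes check_vocab_size reads or writes.
params = SimpleNamespace(n_vocab=32004)
vocab = SimpleNamespace(
    vocab_size=32000,
    added_tokens_dict={},
    fname_tokenizer="tokenizer.model",  # placeholder path, only used in error messages
)

# n_vocab exceeds vocab_size by 4, so pad_vocab=True takes the padding branch:
# <dummy00001> through <dummy00004> are added and vocab_size is bumped to 32004.
check_vocab_size(params, vocab, pad_vocab=True)
assert vocab.vocab_size == 32004
assert "<dummy00004>" in vocab.added_tokens_dict

# The same mismatch with pad_vocab=False raises Exception with the
# "Vocab size mismatch" message, and params.n_vocab == -1 now fails fast
# with ValueError before any other check runs.
```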