commit c2e436e3a0
Mana 2024-07-03 15:11:12 +02:00 committed by GitHub

@@ -7,6 +7,7 @@ import concurrent.futures
import enum
import faulthandler
import functools
import hashlib
import itertools
import json
import math
@@ -24,7 +25,9 @@ from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
from transformers import PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import TikTokenConverter
from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable, Optional
import numpy as np
@@ -51,6 +54,7 @@ DEFAULT_CONCURRENCY = 8
ADDED_TOKENS_FILE = 'added_tokens.json'
FAST_TOKENIZER_FILE = 'tokenizer.json'
# Global flag: set to True in main() when a LLaMA 3 tokenizer.model is detected
is_llama3_model = False
#
# data types
#
@@ -523,6 +527,9 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
        else:
            # split by rows
            axis = 0
        global is_llama3_model
        if name.startswith('tok_embeddings.') and is_llama3_model:
            # LLaMA 3 shards its token embeddings by rows, so override the
            # column split that 'tok_embeddings.' tensors would otherwise get
            axis = 0
        concatenated_shape = list(lazy_tensors[0].shape)
        concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors)
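
The axis choice matters because concatenating shards along the wrong dimension yields a wrong shape rather than an error. A toy NumPy illustration, with made-up shapes rather than the converter's real tensors:

import numpy as np

# two shards of a 32000 x 8 embedding table, split by rows
shards = [np.zeros((16000, 8)), np.zeros((16000, 8))]
by_rows = np.concatenate(shards, axis=0)  # (32000, 8): vocab dimension grows
by_cols = np.concatenate(shards, axis=1)  # (16000, 16): hidden dimension grows
assert by_rows.shape == (32000, 8)
assert by_cols.shape == (16000, 16)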
@@ -896,6 +903,12 @@ class OutputFile:
        tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab)

        # Tokenizer for LLaMA 3: these values match what llama.cpp's
        # convert-hf-to-gguf.py writes for the "llama-bpe" pre-tokenizer
        global is_llama3_model
        if is_llama3_model:
            self.gguf.add_tokenizer_model("gpt2")
            self.gguf.add_tokenizer_pre("llama-bpe")

        # Add extracted token information for model conversion
        self.gguf.add_token_list(tokens)
        self.gguf.add_token_scores(scores)
        self.gguf.add_token_types(toktypes)
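
A quick way to confirm these two fields landed in the output is to read them back with gguf-py's GGUFReader. A hedged sketch: the model.gguf path is hypothetical, and the parts[-1] indexing assumes gguf-py's current layout where a string field's last part holds the UTF-8 payload:

from gguf import GGUFReader

reader = GGUFReader("model.gguf")  # hypothetical converted output
for key in ("tokenizer.ggml.model", "tokenizer.ggml.pre"):
    field = reader.fields[key]
    # for string-typed fields, the last part is the UTF-8 payload
    print(key, "=", bytes(field.parts[-1]).decode("utf-8"))
# expected for a LLaMA 3 model: gpt2 / llama-bpe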
@@ -1208,7 +1221,7 @@ class VocabFactory:
            try:
                vocab = cls(self.path)
                break
            except FileNotFoundError:
            except Exception:
                pass  # ignore unavailable tokenizers
        else:
            raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
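
The loop above relies on Python's for/else: the else clause runs only when the loop finishes without hitting break. A self-contained sketch of the same fallback pattern, where the loader classes are hypothetical stand-ins for the vocab types:

def load_first_available(loader_classes, path):
    for cls in loader_classes:
        try:
            vocab = cls(path)
            break  # success; the else clause below is skipped
        except Exception:
            continue  # this tokenizer type failed to load; try the next one
    else:
        raise FileNotFoundError(f"no tokenizer found in {path}")
    return vocab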
@@ -1274,6 +1287,57 @@ def do_dump_model(model_plus: ModelPlus) -> None:
    for name, lazy_tensor in model_plus.model.items():
        print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}")  # noqa: NP100

# Tokenizer conversion for LLaMA 3
# Credits: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
class Llama3Converter(TikTokenConverter):
    def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs):
        super().__init__(vocab_file, **kwargs)
        tokenizer = self.converted()
        chat_template = (
            "{% set loop_messages = messages %}"
            "{% for message in loop_messages %}"
            "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}"
            "{% if loop.index0 == 0 %}"
            "{% set content = bos_token + content %}"
            "{% endif %}"
            "{{ content }}"
            "{% endfor %}"
            "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
        )
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
        tokenizer.add_special_tokens(special_tokens)
        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            bos_token="<|begin_of_text|>",
            eos_token="<|end_of_text|>",
            chat_template=chat_template,
            model_input_names=["input_ids", "attention_mask"],
        )
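
Once write_llama3_tokenizer below has saved the converted files, the template above can be exercised through apply_chat_template. A minimal sketch, with a hypothetical model_dir output path:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("model_dir")  # hypothetical output directory
messages = [{"role": "user", "content": "Hello!"}]
print(tok.apply_chat_template(messages, tokenize=False))
# Expected, per the template above:
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\n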

def write_llama3_tokenizer(tokenizer_path, input_tokenizer_path):
    tokenizer = Llama3Converter(input_tokenizer_path).tokenizer
    print(f"Saving a {tokenizer.__class__.__name__} to {tokenizer_path}.")
    tokenizer.save_pretrained(tokenizer_path)
    return tokenizer


def is_llama3_tokenizer(tokenizer_path) -> bool:
    # SHA-256 fingerprint of the tokenizer.model file shipped with LLaMA 3
    llama3_tokenizer_model_hash: str = "82e9d31979e92ab929cd544440f129d9ecd797b69e327f80f17e1c50d5551b55"
    with open(tokenizer_path, "rb") as f:
        tokenizer_hash = hashlib.sha256(f.read()).hexdigest()
    return llama3_tokenizer_model_hash == tokenizer_hash
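
The constant above is just the SHA-256 of LLaMA 3's tokenizer.model; it can be reproduced, or computed for a new release, with a few lines of stdlib code, where the path is whatever file you want to fingerprint:

import hashlib

def file_sha256(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)  # stream in 1 MiB chunks instead of reading the whole file
    return h.hexdigest()

print(file_sha256("tokenizer.model"))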

def main(args_in: list[str] | None = None) -> None:
    output_choices = ["f32", "f16"]
@@ -1287,7 +1351,7 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab")
parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft")
parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft,bpe)", default="spm,hfft,bpe")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
@@ -1298,7 +1362,6 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")
parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name")
args = parser.parse_args(args_in)
if args.verbose:
@@ -1311,6 +1374,12 @@ def main(args_in: list[str] | None = None) -> None:
        metadata = Metadata.load(args.metadata)

    # TODO: add more bandaids for llama 3 detection
    tokenizer_model_path = os.path.join(args.model, "tokenizer.model")
    if os.path.exists(tokenizer_model_path) and is_llama3_tokenizer(tokenizer_model_path):
        global is_llama3_model
        write_llama3_tokenizer(args.model, tokenizer_model_path)
        is_llama3_model = True

    if args.get_outfile:
        model_plus = load_some_model(args.model)
        params = Params.load(model_plus)
@@ -1366,6 +1435,7 @@ def main(args_in: list[str] | None = None) -> None:
    logger.info(f"params = {params}")

    model_parent_path = model_plus.paths[0].parent
    vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
    vocab_factory = VocabFactory(vocab_path)