Merge remote-tracking branch 'origin/master' into json-additional
This commit is contained in:
commit
35bbac1e45
25 changed files with 2473 additions and 2569 deletions
6
.github/workflows/docker.yml
vendored
6
.github/workflows/docker.yml
vendored
|
@ -33,15 +33,13 @@ jobs:
|
|||
- { tag: "light", dockerfile: ".devops/llama-cli.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "server", dockerfile: ".devops/llama-server.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "full", dockerfile: ".devops/full.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
# NOTE(canardletter): The CUDA builds on arm64 are very slow, so I
|
||||
# have disabled them for now until the reason why
|
||||
# is understood.
|
||||
- { tag: "light-cuda", dockerfile: ".devops/llama-cli-cuda.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "server-cuda", dockerfile: ".devops/llama-server-cuda.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "full-cuda", dockerfile: ".devops/full-cuda.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "light-rocm", dockerfile: ".devops/llama-cli-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "server-rocm", dockerfile: ".devops/llama-server-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "full-rocm", dockerfile: ".devops/full-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
# Note: the full-rocm image is failing due to a "no space left on device" error. It is disabled for now to allow the workflow to complete.
|
||||
#- { tag: "full-rocm", dockerfile: ".devops/full-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "light-intel", dockerfile: ".devops/llama-cli-intel.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "server-intel", dockerfile: ".devops/llama-server-intel.Dockerfile", platforms: "linux/amd64" }
|
||||
steps:
|
||||
|
|
|
@ -102,7 +102,8 @@ option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM"
|
|||
option(LLAMA_CUDA "llama: use CUDA" OFF)
|
||||
option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
|
||||
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
|
||||
option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
|
||||
option(LLAMA_CUDA_FORCE_MMQ "llama: always use mmq kernels instead of cuBLAS" OFF)
|
||||
option(LLAMA_CUDA_FORCE_CUBLAS "llama: always use cuBLAS instead of mmq kernels" OFF)
|
||||
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
|
||||
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
|
||||
option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
|
||||
|
@ -144,9 +145,6 @@ option(LLAMA_BUILD_SERVER "llama: build server example"
|
|||
option(LLAMA_LASX "llama: enable lasx" ON)
|
||||
option(LLAMA_LSX "llama: enable lsx" ON)
|
||||
|
||||
# add perf arguments
|
||||
option(LLAMA_PERF "llama: enable perf" OFF)
|
||||
|
||||
# Required for relocatable CMake package
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/scripts/build-info.cmake)
|
||||
|
||||
|
@ -419,13 +417,14 @@ if (LLAMA_CUDA)
|
|||
|
||||
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||
# 52 == lowest CUDA 12 standard
|
||||
# 60 == f16 CUDA intrinsics
|
||||
# 60 == FP16 CUDA intrinsics
|
||||
# 61 == integer CUDA intrinsics
|
||||
# 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
|
||||
# 70 == FP16 tensor cores
|
||||
# 75 == int8 tensor cores
|
||||
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
|
||||
#set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
|
||||
endif()
|
||||
endif()
|
||||
|
@ -450,6 +449,9 @@ if (LLAMA_CUDA)
|
|||
if (LLAMA_CUDA_FORCE_MMQ)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
||||
endif()
|
||||
if (LLAMA_CUDA_FORCE_CUBLAS)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
|
||||
endif()
|
||||
if (LLAMA_CUDA_NO_VMM)
|
||||
add_compile_definitions(GGML_CUDA_NO_VMM)
|
||||
endif()
|
||||
|
@ -870,10 +872,6 @@ if (LLAMA_CPU_HBM)
|
|||
target_link_libraries(ggml PUBLIC memkind)
|
||||
endif()
|
||||
|
||||
if (LLAMA_PERF)
|
||||
add_compile_definitions(GGML_PERF)
|
||||
endif()
|
||||
|
||||
function(get_flags CCID CCVER)
|
||||
set(C_FLAGS "")
|
||||
set(CXX_FLAGS "")
|
||||
|
|
6
Makefile
6
Makefile
|
@ -344,9 +344,6 @@ ifdef LLAMA_GPROF
|
|||
MK_CFLAGS += -pg
|
||||
MK_CXXFLAGS += -pg
|
||||
endif
|
||||
ifdef LLAMA_PERF
|
||||
MK_CPPFLAGS += -DGGML_PERF
|
||||
endif
|
||||
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
|
@ -540,6 +537,9 @@ endif # LLAMA_CUDA_FORCE_DMMV
|
|||
ifdef LLAMA_CUDA_FORCE_MMQ
|
||||
MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ
|
||||
endif # LLAMA_CUDA_FORCE_MMQ
|
||||
ifdef LLAMA_CUDA_FORCE_CUBLAS
|
||||
MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS
|
||||
endif # LLAMA_CUDA_FORCE_CUBLAS
|
||||
ifdef LLAMA_CUDA_DMMV_X
|
||||
MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
||||
else
|
||||
|
|
|
@ -510,8 +510,9 @@ Building the program with BLAS support may lead to some performance improvements
|
|||
|--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
|
||||
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
|
||||
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
|
||||
| LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. | |
|
||||
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
|
||||
| LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). Speed for large batch sizes will be worse but VRAM consumption will be lower. |
|
||||
| LLAMA_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models |
|
||||
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
|
||||
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
|
||||
| LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -152,7 +152,6 @@ struct gpt_params {
|
|||
bool prompt_cache_all = false; // save user input and generations to prompt cache
|
||||
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
|
||||
|
||||
bool embedding = false; // get only sentence embedding
|
||||
bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
|
||||
bool multiline_input = false; // reverse the usage of `\`
|
||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
|
@ -179,6 +178,12 @@ struct gpt_params {
|
|||
std::string mmproj = ""; // path to multimodal projector
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
|
||||
// embedding
|
||||
bool embedding = false; // get only sentence embedding
|
||||
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
||||
std::string embd_sep = "\n"; // separator of embendings
|
||||
|
||||
// server params
|
||||
int32_t port = 8080; // server listens on this network port
|
||||
int32_t timeout_read = 600; // http read timeout in seconds
|
||||
|
@ -377,7 +382,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz
|
|||
// Embedding utils
|
||||
//
|
||||
|
||||
void llama_embd_normalize(const float * inp, float * out, int n);
|
||||
void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
|
||||
|
||||
float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
|
||||
|
||||
|
|
|
@ -65,7 +65,8 @@ class Model:
|
|||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None):
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
|
||||
model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
|
||||
if type(self) is Model:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
self.dir_model = dir_model
|
||||
|
@ -80,7 +81,7 @@ class Model:
|
|||
if not self.is_safetensors:
|
||||
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
self.hparams = Model.load_hparams(self.dir_model)
|
||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
|
||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
self.tensor_names = None
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
|
@ -96,7 +97,8 @@ class Model:
|
|||
ftype_lw: str = ftype_up.lower()
|
||||
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls):
|
||||
|
@ -332,6 +334,8 @@ class Model:
|
|||
self.gguf_writer.close()
|
||||
|
||||
def write_vocab(self):
|
||||
if len(self.gguf_writer.tensors) != 1:
|
||||
raise ValueError('Splitting the vocabulary is not supported')
|
||||
self.gguf_writer.write_header_to_file(self.fname_out)
|
||||
self.gguf_writer.write_kv_data_to_file()
|
||||
self.gguf_writer.close()
|
||||
|
@ -2771,6 +2775,124 @@ class DeepseekV2Model(Model):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("T5ForConditionalGeneration")
|
||||
@Model.register("T5WithLMHeadModel")
|
||||
class T5Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.T5
|
||||
|
||||
def set_vocab(self):
|
||||
# to avoid TypeError: Descriptors cannot be created directly
|
||||
# exception when importing sentencepiece_model_pb2
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from sentencepiece import sentencepiece_model_pb2 as model
|
||||
|
||||
tokenizer_path = self.dir_model / 'spiece.model'
|
||||
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto()
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
|
||||
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
tokenizer = SentencePieceProcessor()
|
||||
tokenizer.LoadFromFile(str(tokenizer_path))
|
||||
|
||||
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
|
||||
|
||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||
scores: list[float] = [-10000.0] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
|
||||
|
||||
for token_id in range(tokenizer.vocab_size()):
|
||||
piece = tokenizer.IdToPiece(token_id)
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer.GetScore(token_id)
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.IsUnknown(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.IsControl(token_id):
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.IsUnused(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.IsByte(token_id):
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens[token_id] = text
|
||||
scores[token_id] = score
|
||||
toktypes[token_id] = toktype
|
||||
|
||||
added_tokens_file = self.dir_model / 'added_tokens.json'
|
||||
if added_tokens_file.is_file():
|
||||
with open(added_tokens_file, "r", encoding="utf-8") as f:
|
||||
added_tokens_json = json.load(f)
|
||||
for key in added_tokens_json:
|
||||
token_id = added_tokens_json[key]
|
||||
if (token_id >= vocab_size):
|
||||
logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
|
||||
continue
|
||||
|
||||
tokens[token_id] = key.encode("utf-8")
|
||||
scores[token_id] = -1000.0
|
||||
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
|
||||
|
||||
if vocab_size > len(tokens):
|
||||
pad_count = vocab_size - len(tokens)
|
||||
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
|
||||
for i in range(1, pad_count + 1):
|
||||
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
|
||||
scores.append(-1000.0)
|
||||
toktypes.append(SentencePieceTokenTypes.UNUSED)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("t5")
|
||||
self.gguf_writer.add_tokenizer_pre("default")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
self.gguf_writer.add_add_space_prefix(add_prefix)
|
||||
self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
|
||||
if precompiled_charsmap:
|
||||
self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name("T5")
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||
self.gguf_writer.add_decoder_start_token_id(self.hparams["decoder_start_token_id"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
|
||||
# "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
|
||||
# To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
|
||||
if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
|
||||
logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
@ -2856,10 +2978,44 @@ def parse_args() -> argparse.Namespace:
|
|||
"--verbose", action="store_true",
|
||||
help="increase output verbosity",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-max-tensors", type=int, default=0,
|
||||
help="max tensors in each split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-max-size", type=str, default="0",
|
||||
help="max size per split N(M|G)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run", action="store_true",
|
||||
help="only print out a split plan and exit, without writing any new files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-tensor-first-split", action="store_true",
|
||||
help="do not add tensors to the first split (disabled by default)"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def split_str_to_n_bytes(split_str: str) -> int:
|
||||
if split_str.endswith("K"):
|
||||
n = int(split_str[:-1]) * 1000
|
||||
elif split_str.endswith("M"):
|
||||
n = int(split_str[:-1]) * 1000 * 1000
|
||||
elif split_str.endswith("G"):
|
||||
n = int(split_str[:-1]) * 1000 * 1000 * 1000
|
||||
elif split_str.isnumeric():
|
||||
n = int(split_str)
|
||||
else:
|
||||
raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")
|
||||
|
||||
if n < 0:
|
||||
raise ValueError(f"Invalid split size: {split_str}, must be positive")
|
||||
|
||||
return n
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
|
@ -2892,6 +3048,10 @@ def main() -> None:
|
|||
"auto": gguf.LlamaFileType.GUESSED,
|
||||
}
|
||||
|
||||
if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
|
||||
logger.error("Error: Cannot use temp file when splitting")
|
||||
sys.exit(1)
|
||||
|
||||
if args.outfile is not None:
|
||||
fname_out = args.outfile
|
||||
else:
|
||||
|
@ -2909,7 +3069,10 @@ def main() -> None:
|
|||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name)
|
||||
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
|
||||
args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split)
|
||||
|
||||
logger.info("Set model parameters")
|
||||
model_instance.set_gguf_parameters()
|
||||
|
@ -2920,13 +3083,13 @@ def main() -> None:
|
|||
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
||||
|
||||
if args.vocab_only:
|
||||
logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
|
||||
logger.info("Exporting model vocab...")
|
||||
model_instance.write_vocab()
|
||||
logger.info("Model vocab successfully exported.")
|
||||
else:
|
||||
logger.info(f"Exporting model to '{model_instance.fname_out}'")
|
||||
logger.info("Exporting model...")
|
||||
model_instance.write()
|
||||
|
||||
logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
|
||||
logger.info("Model successfully exported.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -19,3 +19,43 @@ llama-embedding.exe -m ./path/to/model --log-disable -p "Hello World!" 2>$null
|
|||
```
|
||||
|
||||
The above command will output space-separated float values.
|
||||
|
||||
## extra parameters
|
||||
### --embd-normalize $integer$
|
||||
| $integer$ | description | formula |
|
||||
|-----------|---------------------|---------|
|
||||
| $-1$ | none |
|
||||
| $0$ | max absolute int16 | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$
|
||||
| $1$ | taxicab | $\Large{x_i \over\sum \lvert x_i\rvert}$
|
||||
| $2$ | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$
|
||||
| $>2$ | p-norm | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$
|
||||
|
||||
### --embd-output-format $'string'$
|
||||
| $'string'$ | description | |
|
||||
|------------|------------------------------|--|
|
||||
| '' | same as before | (default)
|
||||
| 'array' | single embeddings | $[[x_1,...,x_n]]$
|
||||
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
|
||||
| 'json' | openai style |
|
||||
| 'json+' | add cosine similarity matrix |
|
||||
|
||||
### --embd-separator $"string"$
|
||||
| $"string"$ | |
|
||||
|--------------|-|
|
||||
| "\n" | (default)
|
||||
| "<#embSep#>" | for exemple
|
||||
| "<#sep#>" | other exemple
|
||||
|
||||
## examples
|
||||
### Unix-based systems (Linux, macOS, etc.):
|
||||
|
||||
```bash
|
||||
./embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
|
||||
```
|
||||
|
||||
### Windows:
|
||||
|
||||
```powershell
|
||||
embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
|
||||
```
|
||||
|
||||
|
|
|
@ -7,13 +7,19 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static std::vector<std::string> split_lines(const std::string & s) {
|
||||
std::string line;
|
||||
static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
|
||||
std::vector<std::string> lines;
|
||||
std::stringstream ss(s);
|
||||
while (std::getline(ss, line)) {
|
||||
lines.push_back(line);
|
||||
size_t start = 0;
|
||||
size_t end = s.find(separator);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
lines.push_back(s.substr(start, end - start));
|
||||
start = end + separator.length();
|
||||
end = s.find(separator, start);
|
||||
}
|
||||
|
||||
lines.push_back(s.substr(start)); // Add the last part
|
||||
|
||||
return lines;
|
||||
}
|
||||
|
||||
|
@ -24,7 +30,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
|||
}
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
|
@ -44,13 +50,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|||
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
||||
|
||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||
//TODO: I would also add a parameter here to enable normalization or not.
|
||||
/*fprintf(stdout, "unnormalized_embedding:");
|
||||
for (int hh = 0; hh < n_embd; hh++) {
|
||||
fprintf(stdout, "%9.6f ", embd[hh]);
|
||||
}
|
||||
fprintf(stdout, "\n");*/
|
||||
llama_embd_normalize(embd, out, n_embd);
|
||||
llama_embd_normalize(embd, out, n_embd, embd_norm);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// split the prompt into lines
|
||||
std::vector<std::string> prompts = split_lines(params.prompt);
|
||||
std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
|
||||
|
||||
// max batch size
|
||||
const uint64_t n_batch = params.n_batch;
|
||||
|
@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
|
|||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch) {
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
||||
llama_batch_clear(batch);
|
||||
p += s;
|
||||
s = 0;
|
||||
|
@ -183,29 +183,78 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// final batch
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
|
||||
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
fprintf(stdout, "\n");
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
fprintf(stdout, "embedding %d: ", j);
|
||||
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
|
||||
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
||||
}
|
||||
if (params.embd_out.empty()) {
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
// print cosine similarity matrix
|
||||
if (n_prompts > 1) {
|
||||
fprintf(stdout, "\n");
|
||||
printf("cosine similarity matrix:\n\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f ", sim);
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
fprintf(stdout, "embedding %d: ", j);
|
||||
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
|
||||
} else {
|
||||
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
|
||||
}
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
// print cosine similarity matrix
|
||||
if (n_prompts > 1) {
|
||||
fprintf(stdout, "\n");
|
||||
printf("cosine similarity matrix:\n\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
fprintf(stdout, "%6.6s ", prompts[i].c_str());
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f ", sim);
|
||||
}
|
||||
fprintf(stdout, "%1.10s", prompts[i].c_str());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
|
||||
const bool notArray = params.embd_out != "array";
|
||||
|
||||
fprintf(stdout, notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
|
||||
for (int j = 0;;) { // at least one iteration (one prompt)
|
||||
if (notArray) fprintf(stdout, " {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
|
||||
fprintf(stdout, "[");
|
||||
for (int i = 0;;) { // at least one iteration (n_embd > 0)
|
||||
fprintf(stdout, params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
|
||||
i++;
|
||||
if (i < n_embd) fprintf(stdout, ","); else break;
|
||||
}
|
||||
fprintf(stdout, notArray ? "]\n }" : "]");
|
||||
j++;
|
||||
if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
|
||||
}
|
||||
fprintf(stdout, notArray ? "\n ]" : "]\n");
|
||||
|
||||
if (params.embd_out == "json+" && n_prompts > 1) {
|
||||
fprintf(stdout, ",\n \"cosineSimilarity\": [\n");
|
||||
for (int i = 0;;) { // at least two iteration (n_prompts > 1)
|
||||
fprintf(stdout, " [");
|
||||
for (int j = 0;;) { // at least two iteration (n_prompts > 1)
|
||||
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
|
||||
fprintf(stdout, "%6.2f", sim);
|
||||
j++;
|
||||
if (j < n_prompts) fprintf(stdout, ", "); else break;
|
||||
}
|
||||
fprintf(stdout, " ]");
|
||||
i++;
|
||||
if (i < n_prompts) fprintf(stdout, ",\n"); else break;
|
||||
}
|
||||
fprintf(stdout, "\n ]");
|
||||
}
|
||||
|
||||
if (notArray) fprintf(stdout, "\n}\n");
|
||||
}
|
||||
|
||||
// clean up
|
||||
|
|
96
ggml-cuda.cu
96
ggml-cuda.cu
|
@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
|
||||
|
||||
int64_t total_vram = 0;
|
||||
#if defined(GGML_CUDA_FORCE_MMQ)
|
||||
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
|
||||
#ifdef GGML_CUDA_FORCE_MMQ
|
||||
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
|
||||
#else
|
||||
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
|
||||
#endif
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
|
||||
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
|
||||
#endif // GGML_CUDA_FORCE_MMQ
|
||||
#ifdef GGML_CUDA_FORCE_CUBLAS
|
||||
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
|
||||
#else
|
||||
GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
|
||||
#endif
|
||||
GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
|
||||
#endif // GGML_CUDA_FORCE_CUBLAS
|
||||
GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
@ -1873,9 +1873,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
|
||||
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
bool use_mul_mat_q = ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
bool any_gpus_with_slow_fp16 = false;
|
||||
|
||||
bool any_pascal_with_slow_fp16 = false;
|
||||
if (split) {
|
||||
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
auto & tensor_split = buft_ctx->tensor_split;
|
||||
|
@ -1885,55 +1893,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|||
continue;
|
||||
}
|
||||
|
||||
if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
|
||||
min_compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
}
|
||||
if (ggml_cuda_info().devices[id].cc == 610) {
|
||||
any_pascal_with_slow_fp16 = true;
|
||||
}
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
||||
}
|
||||
} else {
|
||||
min_compute_capability = ggml_cuda_info().devices[ctx.device].cc;
|
||||
any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610;
|
||||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||
use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
|
||||
}
|
||||
|
||||
// check data types and tensor shapes for custom matrix multiplication kernels:
|
||||
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
|
||||
bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
|
||||
|
||||
#ifdef CUDA_USE_TENSOR_CORES
|
||||
use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
|
||||
#endif // CUDA_USE_TENSOR_CORES
|
||||
|
||||
#else
|
||||
|
||||
// fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
|
||||
const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
|
||||
|
||||
// mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
|
||||
use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
|
||||
use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A;
|
||||
|
||||
#ifdef CUDA_USE_TENSOR_CORES
|
||||
// when tensor cores are available, use them for large batch size
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
|
||||
use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
||||
#endif // CUDA_USE_TENSOR_CORES
|
||||
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
// if mmvq is available it's a better choice than dmmv:
|
||||
#ifndef GGML_CUDA_FORCE_DMMV
|
||||
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||
|
@ -1947,21 +1918,22 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// KQ single-batch
|
||||
if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// FP32 precision KQ single-batch for batch size 1 without FlashAttention
|
||||
ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
|
||||
} else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
} else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// FP32 precision KQV single-batch for batch size 1 without FlashAttention
|
||||
ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
|
||||
} else if (use_dequantize_mul_mat_vec) {
|
||||
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
|
||||
} else if (use_mul_mat_vec_q) {
|
||||
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
|
||||
} else if (use_mul_mat_q) {
|
||||
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
|
||||
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
// KQ + KQV multi-batch without FlashAttention
|
||||
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
|
||||
} else {
|
||||
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
|
||||
}
|
||||
|
|
|
@ -146,23 +146,6 @@
|
|||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
|
||||
|
||||
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
|
||||
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
|
||||
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
|
||||
// - 7B quantum model: +100-200 MB
|
||||
// - 13B quantum model: +200-400 MB
|
||||
//
|
||||
//#define GGML_CUDA_FORCE_MMQ
|
||||
|
||||
// TODO: improve this to be correct for more hardware
|
||||
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
|
||||
#if !defined(GGML_CUDA_FORCE_MMQ)
|
||||
#define CUDA_USE_TENSOR_CORES
|
||||
#endif
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
|
||||
#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
|
||||
|
||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
@ -343,15 +326,15 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
|
|||
#define INT8_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
static constexpr bool fast_fp16_available(const int cc) {
|
||||
return cc >= CC_PASCAL && cc != 610;
|
||||
}
|
||||
|
||||
static bool fp16_mma_available(const int cc) {
|
||||
static constexpr bool fp16_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
|
||||
}
|
||||
|
||||
static bool int8_mma_available(const int cc) {
|
||||
static constexpr bool int8_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
|
||||
}
|
||||
|
||||
|
@ -643,19 +626,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
|
|||
static constexpr int qi = QI3_S;
|
||||
};
|
||||
|
||||
static int get_mmq_x_max_host(const int cc) {
|
||||
#ifdef CUDA_USE_TENSOR_CORES
|
||||
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
|
||||
#else
|
||||
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
|
||||
#endif // CUDA_USE_TENSOR_CORES
|
||||
}
|
||||
|
||||
// Round rows to this value for --split-mode row:
|
||||
static int get_mmq_y_host(const int cc) {
|
||||
return cc >= CC_VOLTA ? 128 : 64;
|
||||
}
|
||||
|
||||
//////////////////////
|
||||
|
||||
struct ggml_cuda_device_info {
|
||||
|
|
|
@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
|
|||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE)
|
||||
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_A_I16K8 {
|
||||
|
@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
|
|||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE)
|
||||
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K4 {
|
||||
|
@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
|
|||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
|
||||
const int * xs = xs0 + (threadIdx.x%J)*stride;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
|
||||
: "+r"(x[0])
|
||||
: "l"(xs));
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K8 {
|
||||
|
@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
|
|||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
|
||||
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
|
||||
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_C_I16J8 {
|
||||
|
|
|
@ -69,7 +69,13 @@ void ggml_cuda_op_mul_mat_q(
|
|||
GGML_UNUSED(src1_ddf_i);
|
||||
}
|
||||
|
||||
bool ggml_cuda_supports_mmq(enum ggml_type type) {
|
||||
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
#ifdef GGML_CUDA_FORCE_CUBLAS
|
||||
return false;
|
||||
#endif // GGML_CUDA_FORCE_CUBLAS
|
||||
|
||||
bool mmq_supported;
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
|
@ -81,8 +87,32 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) {
|
|||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
mmq_supported = true;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
mmq_supported = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!mmq_supported) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (int8_mma_available(cc)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cc < MIN_CC_DP4A) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_FORCE_MMQ
|
||||
return true;
|
||||
#endif //GGML_CUDA_FORCE_MMQ
|
||||
|
||||
if (cc < CC_OFFSET_AMD) {
|
||||
return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
|
1461
ggml-cuda/mmq.cuh
1461
ggml-cuda/mmq.cuh
File diff suppressed because it is too large
Load diff
|
@ -1,5 +1,7 @@
|
|||
#include "common.cuh"
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
|
|
|
@ -513,8 +513,8 @@ static size_t vk_skip_checks;
|
|||
static size_t vk_output_tensor;
|
||||
|
||||
static void ggml_vk_print_tensor(ggml_backend * ctx, const ggml_tensor * tensor, const char * name);
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor);
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor);
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
|
||||
#endif
|
||||
|
||||
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
@ -5644,7 +5644,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor){
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor){
|
||||
ggml_tensor_extra_gpu * extra = nullptr;
|
||||
|
||||
switch (tensor->op) {
|
||||
|
@ -5697,17 +5697,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
|
|||
return false;
|
||||
}
|
||||
|
||||
if (params->ith != 0) {
|
||||
return true;
|
||||
}
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||
return true;
|
||||
}
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
|
||||
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
ggml_vk_check_results_0(ctx, params, tensor);
|
||||
ggml_vk_check_results_0(ctx, tensor);
|
||||
#endif
|
||||
|
||||
vk_context& subctx = ctx->gc.contexts[extra->ctx_idx];
|
||||
|
@ -6214,9 +6207,6 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
|
|||
ggml_vk_build_graph(ctx,cgraph->nodes[i], i == last_node);
|
||||
}
|
||||
|
||||
ggml_compute_params params = {};
|
||||
params.type = GGML_TASK_TYPE_COMPUTE;
|
||||
params.ith = 0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
@ -6224,13 +6214,13 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
|
|||
continue;
|
||||
}
|
||||
|
||||
bool ok = ggml_vk_compute_forward(ctx, ¶ms, node);
|
||||
bool ok = ggml_vk_compute_forward(ctx, node);
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
}
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
else {
|
||||
ggml_vk_check_results_1(ctx, ¶ms, node);
|
||||
ggml_vk_check_results_1(ctx, node);
|
||||
}
|
||||
#endif
|
||||
GGML_ASSERT(ok);
|
||||
|
@ -6600,11 +6590,8 @@ void * comp_result;
|
|||
size_t comp_size;
|
||||
size_t comp_nb[GGML_MAX_DIMS];
|
||||
size_t check_counter = 0;
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor) {
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
|
||||
if (tensor->op == GGML_OP_TRANSPOSE) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -6908,11 +6895,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
|
|||
ggml_free(ggml_ctx);
|
||||
}
|
||||
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor) {
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
|
||||
if (tensor->op == GGML_OP_TRANSPOSE) {
|
||||
return;
|
||||
}
|
||||
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
||||
|
|
35
ggml.h
35
ggml.h
|
@ -591,11 +591,7 @@ extern "C" {
|
|||
struct ggml_tensor * grad;
|
||||
struct ggml_tensor * src[GGML_MAX_SRC];
|
||||
|
||||
// performance
|
||||
int perf_runs;
|
||||
int64_t perf_cycles;
|
||||
int64_t perf_time_us;
|
||||
|
||||
// source tensor and offset for views
|
||||
struct ggml_tensor * view_src;
|
||||
size_t view_offs;
|
||||
|
||||
|
@ -605,7 +601,7 @@ extern "C" {
|
|||
|
||||
void * extra; // extra things e.g. for ggml-cuda.cu
|
||||
|
||||
char padding[8];
|
||||
// char padding[4];
|
||||
};
|
||||
|
||||
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
|
||||
|
@ -652,11 +648,6 @@ extern "C" {
|
|||
struct ggml_hash_set visited_hash_table;
|
||||
|
||||
enum ggml_cgraph_eval_order order;
|
||||
|
||||
// performance
|
||||
int perf_runs;
|
||||
int64_t perf_cycles;
|
||||
int64_t perf_time_us;
|
||||
};
|
||||
|
||||
// scratch buffer
|
||||
|
@ -673,28 +664,6 @@ extern "C" {
|
|||
bool no_alloc; // don't allocate memory for the tensor data
|
||||
};
|
||||
|
||||
|
||||
// compute types
|
||||
|
||||
// NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
|
||||
// This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
|
||||
enum ggml_task_type {
|
||||
GGML_TASK_TYPE_INIT = 0,
|
||||
GGML_TASK_TYPE_COMPUTE,
|
||||
GGML_TASK_TYPE_FINALIZE,
|
||||
};
|
||||
|
||||
struct ggml_compute_params {
|
||||
enum ggml_task_type type;
|
||||
|
||||
// ith = thread index, nth = number of threads
|
||||
int ith, nth;
|
||||
|
||||
// work buffer for all threads
|
||||
size_t wsize;
|
||||
void * wdata;
|
||||
};
|
||||
|
||||
// numa strategies
|
||||
enum ggml_numa_strategy {
|
||||
GGML_NUMA_STRATEGY_DISABLED = 0,
|
||||
|
|
|
@ -49,6 +49,7 @@ class Keys:
|
|||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
|
@ -62,6 +63,7 @@ class Keys:
|
|||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
|
@ -73,6 +75,11 @@ class Keys:
|
|||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
|
||||
|
||||
class Split:
|
||||
LLM_KV_SPLIT_NO = "split.no"
|
||||
LLM_KV_SPLIT_COUNT = "split.count"
|
||||
LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
|
||||
|
||||
class SSM:
|
||||
CONV_KERNEL = "{arch}.ssm.conv_kernel"
|
||||
INNER_SIZE = "{arch}.ssm.inner_size"
|
||||
|
@ -80,33 +87,35 @@ class Keys:
|
|||
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
|
||||
|
||||
class Tokenizer:
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
PRE = "tokenizer.ggml.pre"
|
||||
LIST = "tokenizer.ggml.tokens"
|
||||
TOKEN_TYPE = "tokenizer.ggml.token_type"
|
||||
TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
|
||||
SCORES = "tokenizer.ggml.scores"
|
||||
MERGES = "tokenizer.ggml.merges"
|
||||
BOS_ID = "tokenizer.ggml.bos_token_id"
|
||||
EOS_ID = "tokenizer.ggml.eos_token_id"
|
||||
UNK_ID = "tokenizer.ggml.unknown_token_id"
|
||||
SEP_ID = "tokenizer.ggml.seperator_token_id"
|
||||
PAD_ID = "tokenizer.ggml.padding_token_id"
|
||||
CLS_ID = "tokenizer.ggml.cls_token_id"
|
||||
MASK_ID = "tokenizer.ggml.mask_token_id"
|
||||
ADD_BOS = "tokenizer.ggml.add_bos_token"
|
||||
ADD_EOS = "tokenizer.ggml.add_eos_token"
|
||||
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
|
||||
HF_JSON = "tokenizer.huggingface.json"
|
||||
RWKV = "tokenizer.rwkv.world"
|
||||
CHAT_TEMPLATE = "tokenizer.chat_template"
|
||||
CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
|
||||
CHAT_TEMPLATES = "tokenizer.chat_templates"
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
PRE = "tokenizer.ggml.pre"
|
||||
LIST = "tokenizer.ggml.tokens"
|
||||
TOKEN_TYPE = "tokenizer.ggml.token_type"
|
||||
TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
|
||||
SCORES = "tokenizer.ggml.scores"
|
||||
MERGES = "tokenizer.ggml.merges"
|
||||
BOS_ID = "tokenizer.ggml.bos_token_id"
|
||||
EOS_ID = "tokenizer.ggml.eos_token_id"
|
||||
UNK_ID = "tokenizer.ggml.unknown_token_id"
|
||||
SEP_ID = "tokenizer.ggml.seperator_token_id"
|
||||
PAD_ID = "tokenizer.ggml.padding_token_id"
|
||||
CLS_ID = "tokenizer.ggml.cls_token_id"
|
||||
MASK_ID = "tokenizer.ggml.mask_token_id"
|
||||
ADD_BOS = "tokenizer.ggml.add_bos_token"
|
||||
ADD_EOS = "tokenizer.ggml.add_eos_token"
|
||||
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
|
||||
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
|
||||
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
|
||||
HF_JSON = "tokenizer.huggingface.json"
|
||||
RWKV = "tokenizer.rwkv.world"
|
||||
CHAT_TEMPLATE = "tokenizer.chat_template"
|
||||
CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
|
||||
CHAT_TEMPLATES = "tokenizer.chat_templates"
|
||||
# FIM/Infill special tokens constants
|
||||
PREFIX_ID = "tokenizer.ggml.prefix_token_id"
|
||||
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
|
||||
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
|
||||
EOT_ID = "tokenizer.ggml.eot_token_id"
|
||||
PREFIX_ID = "tokenizer.ggml.prefix_token_id"
|
||||
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
|
||||
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
|
||||
EOT_ID = "tokenizer.ggml.eot_token_id"
|
||||
|
||||
|
||||
#
|
||||
|
@ -115,94 +124,123 @@ class Keys:
|
|||
|
||||
|
||||
class MODEL_ARCH(IntEnum):
|
||||
LLAMA = auto()
|
||||
FALCON = auto()
|
||||
BAICHUAN = auto()
|
||||
GROK = auto()
|
||||
GPT2 = auto()
|
||||
GPTJ = auto()
|
||||
GPTNEOX = auto()
|
||||
MPT = auto()
|
||||
STARCODER = auto()
|
||||
REFACT = auto()
|
||||
BERT = auto()
|
||||
NOMIC_BERT = auto()
|
||||
LLAMA = auto()
|
||||
FALCON = auto()
|
||||
BAICHUAN = auto()
|
||||
GROK = auto()
|
||||
GPT2 = auto()
|
||||
GPTJ = auto()
|
||||
GPTNEOX = auto()
|
||||
MPT = auto()
|
||||
STARCODER = auto()
|
||||
REFACT = auto()
|
||||
BERT = auto()
|
||||
NOMIC_BERT = auto()
|
||||
JINA_BERT_V2 = auto()
|
||||
BLOOM = auto()
|
||||
STABLELM = auto()
|
||||
QWEN = auto()
|
||||
QWEN2 = auto()
|
||||
QWEN2MOE = auto()
|
||||
PHI2 = auto()
|
||||
PHI3 = auto()
|
||||
PLAMO = auto()
|
||||
CODESHELL = auto()
|
||||
ORION = auto()
|
||||
INTERNLM2 = auto()
|
||||
MINICPM = auto()
|
||||
GEMMA = auto()
|
||||
STARCODER2 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
DBRX = auto()
|
||||
OLMO = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
BITNET = auto()
|
||||
BLOOM = auto()
|
||||
STABLELM = auto()
|
||||
QWEN = auto()
|
||||
QWEN2 = auto()
|
||||
QWEN2MOE = auto()
|
||||
PHI2 = auto()
|
||||
PHI3 = auto()
|
||||
PLAMO = auto()
|
||||
CODESHELL = auto()
|
||||
ORION = auto()
|
||||
INTERNLM2 = auto()
|
||||
MINICPM = auto()
|
||||
GEMMA = auto()
|
||||
STARCODER2 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
DBRX = auto()
|
||||
OLMO = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
BITNET = auto()
|
||||
T5 = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
TOKEN_EMBD = auto()
|
||||
TOKEN_EMBD_NORM = auto()
|
||||
TOKEN_TYPES = auto()
|
||||
POS_EMBD = auto()
|
||||
OUTPUT = auto()
|
||||
OUTPUT_NORM = auto()
|
||||
ROPE_FREQS = auto()
|
||||
ROPE_FACTORS_LONG = auto()
|
||||
ROPE_FACTORS_SHORT = auto()
|
||||
ATTN_Q = auto()
|
||||
ATTN_K = auto()
|
||||
ATTN_V = auto()
|
||||
ATTN_QKV = auto()
|
||||
ATTN_OUT = auto()
|
||||
ATTN_NORM = auto()
|
||||
ATTN_NORM_2 = auto()
|
||||
ATTN_OUT_NORM = auto()
|
||||
ATTN_ROT_EMBD = auto()
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
FFN_GATE = auto()
|
||||
FFN_DOWN = auto()
|
||||
FFN_UP = auto()
|
||||
FFN_ACT = auto()
|
||||
FFN_NORM_EXP = auto()
|
||||
FFN_GATE_EXP = auto()
|
||||
FFN_DOWN_EXP = auto()
|
||||
FFN_UP_EXP = auto()
|
||||
FFN_GATE_SHEXP = auto()
|
||||
FFN_DOWN_SHEXP = auto()
|
||||
FFN_UP_SHEXP = auto()
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
LAYER_OUT_NORM = auto()
|
||||
SSM_IN = auto()
|
||||
SSM_CONV1D = auto()
|
||||
SSM_X = auto()
|
||||
SSM_DT = auto()
|
||||
SSM_A = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
ATTN_Q_A = auto()
|
||||
ATTN_Q_B = auto()
|
||||
ATTN_KV_A_MQA = auto()
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
FFN_SUB_NORM = auto()
|
||||
ATTN_SUB_NORM = auto()
|
||||
TOKEN_EMBD = auto()
|
||||
TOKEN_EMBD_NORM = auto()
|
||||
TOKEN_TYPES = auto()
|
||||
POS_EMBD = auto()
|
||||
OUTPUT = auto()
|
||||
OUTPUT_NORM = auto()
|
||||
ROPE_FREQS = auto()
|
||||
ROPE_FACTORS_LONG = auto()
|
||||
ROPE_FACTORS_SHORT = auto()
|
||||
ATTN_Q = auto()
|
||||
ATTN_K = auto()
|
||||
ATTN_V = auto()
|
||||
ATTN_QKV = auto()
|
||||
ATTN_OUT = auto()
|
||||
ATTN_NORM = auto()
|
||||
ATTN_NORM_2 = auto()
|
||||
ATTN_OUT_NORM = auto()
|
||||
ATTN_ROT_EMBD = auto()
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
FFN_GATE = auto()
|
||||
FFN_DOWN = auto()
|
||||
FFN_UP = auto()
|
||||
FFN_ACT = auto()
|
||||
FFN_NORM_EXP = auto()
|
||||
FFN_GATE_EXP = auto()
|
||||
FFN_DOWN_EXP = auto()
|
||||
FFN_UP_EXP = auto()
|
||||
FFN_GATE_SHEXP = auto()
|
||||
FFN_DOWN_SHEXP = auto()
|
||||
FFN_UP_SHEXP = auto()
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
LAYER_OUT_NORM = auto()
|
||||
SSM_IN = auto()
|
||||
SSM_CONV1D = auto()
|
||||
SSM_X = auto()
|
||||
SSM_DT = auto()
|
||||
SSM_A = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
ATTN_Q_A = auto()
|
||||
ATTN_Q_B = auto()
|
||||
ATTN_KV_A_MQA = auto()
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
FFN_SUB_NORM = auto()
|
||||
ATTN_SUB_NORM = auto()
|
||||
DEC_ATTN_NORM = auto()
|
||||
DEC_ATTN_Q = auto()
|
||||
DEC_ATTN_K = auto()
|
||||
DEC_ATTN_V = auto()
|
||||
DEC_ATTN_OUT = auto()
|
||||
DEC_ATTN_REL_B = auto()
|
||||
DEC_CROSS_ATTN_NORM = auto()
|
||||
DEC_CROSS_ATTN_Q = auto()
|
||||
DEC_CROSS_ATTN_K = auto()
|
||||
DEC_CROSS_ATTN_V = auto()
|
||||
DEC_CROSS_ATTN_OUT = auto()
|
||||
DEC_CROSS_ATTN_REL_B = auto()
|
||||
DEC_FFN_NORM = auto()
|
||||
DEC_FFN_GATE = auto()
|
||||
DEC_FFN_DOWN = auto()
|
||||
DEC_FFN_UP = auto()
|
||||
DEC_OUTPUT_NORM = auto()
|
||||
ENC_ATTN_NORM = auto()
|
||||
ENC_ATTN_Q = auto()
|
||||
ENC_ATTN_K = auto()
|
||||
ENC_ATTN_V = auto()
|
||||
ENC_ATTN_OUT = auto()
|
||||
ENC_ATTN_REL_B = auto()
|
||||
ENC_FFN_NORM = auto()
|
||||
ENC_FFN_GATE = auto()
|
||||
ENC_FFN_DOWN = auto()
|
||||
ENC_FFN_UP = auto()
|
||||
ENC_OUTPUT_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
|
@ -241,59 +279,88 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.BITNET: "bitnet",
|
||||
MODEL_ARCH.T5: "t5",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
|
||||
MODEL_TENSOR.TOKEN_TYPES: "token_types",
|
||||
MODEL_TENSOR.POS_EMBD: "position_embd",
|
||||
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
||||
MODEL_TENSOR.OUTPUT: "output",
|
||||
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
|
||||
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
|
||||
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
|
||||
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
|
||||
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
|
||||
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
|
||||
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
|
||||
MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
|
||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
|
||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
|
||||
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
||||
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
|
||||
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
|
||||
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
|
||||
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
|
||||
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
|
||||
MODEL_TENSOR.TOKEN_TYPES: "token_types",
|
||||
MODEL_TENSOR.POS_EMBD: "position_embd",
|
||||
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
||||
MODEL_TENSOR.OUTPUT: "output",
|
||||
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
|
||||
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
|
||||
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
|
||||
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
|
||||
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
|
||||
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
|
||||
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
|
||||
MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
|
||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
|
||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
|
||||
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
||||
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
|
||||
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
|
||||
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
|
||||
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
|
||||
MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
|
||||
MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o",
|
||||
MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b",
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm",
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q",
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k",
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v",
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o",
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b",
|
||||
MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm",
|
||||
MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm",
|
||||
MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o",
|
||||
MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b",
|
||||
MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
|
@ -829,6 +896,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ATTN_SUB_NORM,
|
||||
MODEL_TENSOR.FFN_SUB_NORM,
|
||||
],
|
||||
MODEL_ARCH.T5: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.DEC_ATTN_NORM,
|
||||
MODEL_TENSOR.DEC_ATTN_Q,
|
||||
MODEL_TENSOR.DEC_ATTN_K,
|
||||
MODEL_TENSOR.DEC_ATTN_V,
|
||||
MODEL_TENSOR.DEC_ATTN_OUT,
|
||||
MODEL_TENSOR.DEC_ATTN_REL_B,
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_NORM,
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_Q,
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_K,
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_V,
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_OUT,
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_REL_B,
|
||||
MODEL_TENSOR.DEC_FFN_NORM,
|
||||
MODEL_TENSOR.DEC_FFN_GATE,
|
||||
MODEL_TENSOR.DEC_FFN_DOWN,
|
||||
MODEL_TENSOR.DEC_FFN_UP,
|
||||
MODEL_TENSOR.DEC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.ENC_ATTN_NORM,
|
||||
MODEL_TENSOR.ENC_ATTN_Q,
|
||||
MODEL_TENSOR.ENC_ATTN_K,
|
||||
MODEL_TENSOR.ENC_ATTN_V,
|
||||
MODEL_TENSOR.ENC_ATTN_OUT,
|
||||
MODEL_TENSOR.ENC_ATTN_REL_B,
|
||||
MODEL_TENSOR.ENC_FFN_NORM,
|
||||
MODEL_TENSOR.ENC_FFN_GATE,
|
||||
MODEL_TENSOR.ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.ENC_FFN_UP,
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import struct
|
|||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from io import BufferedWriter
|
||||
from typing import IO, Any, Sequence, Mapping
|
||||
from string import ascii_letters, digits
|
||||
|
@ -31,6 +32,9 @@ from .quants import quant_shape_from_byte_shape
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
shape: Sequence[int]
|
||||
|
@ -55,11 +59,11 @@ class WriterState(Enum):
|
|||
|
||||
|
||||
class GGUFWriter:
|
||||
fout: BufferedWriter | None
|
||||
path: os.PathLike[str] | str | None
|
||||
fout: list[BufferedWriter] | None
|
||||
path: Path | None
|
||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||
tensors: dict[str, TensorInfo]
|
||||
kv_data: dict[str, GGUFValue]
|
||||
tensors: list[dict[str, TensorInfo]]
|
||||
kv_data: list[dict[str, GGUFValue]]
|
||||
state: WriterState
|
||||
_simple_value_packing = {
|
||||
GGUFValueType.UINT8: "B",
|
||||
|
@ -76,26 +80,38 @@ class GGUFWriter:
|
|||
}
|
||||
|
||||
def __init__(
|
||||
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
|
||||
endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
|
||||
):
|
||||
self.fout = None
|
||||
self.path = path
|
||||
self.path = Path(path) if path else None
|
||||
self.arch = arch
|
||||
self.endianess = endianess
|
||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||
self.use_temp_file = use_temp_file
|
||||
self.temp_file = None
|
||||
self.tensors = dict()
|
||||
self.kv_data = dict()
|
||||
self.tensors = [{}]
|
||||
self.kv_data = [{}]
|
||||
self.split_max_tensors = split_max_tensors
|
||||
self.split_max_size = split_max_size
|
||||
self.dry_run = dry_run
|
||||
self.small_first_shard = small_first_shard
|
||||
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
||||
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
||||
))
|
||||
self.state = WriterState.NO_FILE
|
||||
|
||||
if self.small_first_shard:
|
||||
self.tensors.append({})
|
||||
|
||||
self.add_architecture()
|
||||
|
||||
def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
def format_shard_names(self, path: Path) -> list[Path]:
|
||||
if len(self.tensors) == 1:
|
||||
return [path]
|
||||
return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
|
||||
|
||||
def open_output_file(self, path: Path | None = None) -> None:
|
||||
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
|
||||
# allow calling this multiple times as long as the path is the same
|
||||
return
|
||||
|
@ -106,22 +122,58 @@ class GGUFWriter:
|
|||
self.path = path
|
||||
|
||||
if self.path is not None:
|
||||
if self.fout is not None:
|
||||
self.fout.close()
|
||||
self.fout = open(self.path, "wb")
|
||||
filenames = self.print_plan()
|
||||
self.fout = [open(filename, "wb") for filename in filenames]
|
||||
self.state = WriterState.EMPTY
|
||||
|
||||
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
def print_plan(self) -> list[Path]:
|
||||
logger.info("Writing the following files:")
|
||||
assert self.path is not None
|
||||
filenames = self.format_shard_names(self.path)
|
||||
assert len(filenames) == len(self.tensors)
|
||||
for name, tensors in zip(filenames, self.tensors):
|
||||
logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
|
||||
|
||||
if self.dry_run:
|
||||
logger.info("Dry run, not writing files")
|
||||
exit()
|
||||
|
||||
return filenames
|
||||
|
||||
def add_shard_kv_data(self) -> None:
|
||||
if len(self.tensors) == 1:
|
||||
return
|
||||
|
||||
total_tensors = sum(len(t) for t in self.tensors)
|
||||
assert self.fout is not None
|
||||
total_splits = len(self.fout)
|
||||
self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
|
||||
for i, kv_data in enumerate(self.kv_data):
|
||||
kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
|
||||
kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
|
||||
kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
|
||||
|
||||
def write_header_to_file(self, path: Path | None = None) -> None:
|
||||
if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
|
||||
logger.warning("Model fails split requirements, not splitting")
|
||||
|
||||
self.open_output_file(path)
|
||||
|
||||
if self.state is not WriterState.EMPTY:
|
||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||
|
||||
self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
|
||||
self._write_packed("I", GGUF_VERSION)
|
||||
self._write_packed("Q", len(self.tensors))
|
||||
self._write_packed("Q", len(self.kv_data))
|
||||
self.flush()
|
||||
assert self.fout is not None
|
||||
assert len(self.fout) == len(self.tensors)
|
||||
assert len(self.kv_data) == 1
|
||||
|
||||
self.add_shard_kv_data()
|
||||
|
||||
for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
|
||||
fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
|
||||
fout.write(self._pack("I", GGUF_VERSION))
|
||||
fout.write(self._pack("Q", len(tensors)))
|
||||
fout.write(self._pack("Q", len(kv_data)))
|
||||
fout.flush()
|
||||
self.state = WriterState.HEADER
|
||||
|
||||
def write_kv_data_to_file(self) -> None:
|
||||
|
@ -129,13 +181,15 @@ class GGUFWriter:
|
|||
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
kv_data = bytearray()
|
||||
for fout, kv_data in zip(self.fout, self.kv_data):
|
||||
kv_bytes = bytearray()
|
||||
|
||||
for key, val in self.kv_data.items():
|
||||
kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_data += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
for key, val in kv_data.items():
|
||||
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
|
||||
fout.write(kv_bytes)
|
||||
|
||||
self.fout.write(kv_data)
|
||||
self.flush()
|
||||
self.state = WriterState.KV_DATA
|
||||
|
||||
|
@ -144,28 +198,29 @@ class GGUFWriter:
|
|||
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
ti_data = bytearray()
|
||||
offset_tensor = 0
|
||||
for fout, tensors in zip(self.fout, self.tensors):
|
||||
ti_data = bytearray()
|
||||
offset_tensor = 0
|
||||
|
||||
for name, ti in self.tensors.items():
|
||||
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||
n_dims = len(ti.shape)
|
||||
ti_data += self._pack("I", n_dims)
|
||||
for i in range(n_dims):
|
||||
ti_data += self._pack("Q", ti.shape[n_dims - 1 - i])
|
||||
ti_data += self._pack("I", ti.dtype)
|
||||
ti_data += self._pack("Q", offset_tensor)
|
||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||
for name, ti in tensors.items():
|
||||
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||
n_dims = len(ti.shape)
|
||||
ti_data += self._pack("I", n_dims)
|
||||
for j in range(n_dims):
|
||||
ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
|
||||
ti_data += self._pack("I", ti.dtype)
|
||||
ti_data += self._pack("Q", offset_tensor)
|
||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||
|
||||
self.fout.write(ti_data)
|
||||
self.flush()
|
||||
fout.write(ti_data)
|
||||
fout.flush()
|
||||
self.state = WriterState.TI_DATA
|
||||
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||
if key in self.kv_data:
|
||||
if any(key in kv_data for kv_data in self.kv_data):
|
||||
raise ValueError(f'Duplicated key name {key!r}')
|
||||
|
||||
self.kv_data[key] = GGUFValue(value=val, type=vtype)
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
|
||||
|
||||
def add_uint8(self, key: str, val: int) -> None:
|
||||
self.add_key_value(key,val, GGUFValueType.UINT8)
|
||||
|
@ -206,9 +261,6 @@ class GGUFWriter:
|
|||
self.add_key_value(key, val, GGUFValueType.STRING)
|
||||
|
||||
def add_array(self, key: str, val: Sequence[Any]) -> None:
|
||||
if not isinstance(val, Sequence):
|
||||
raise ValueError("Value must be a sequence for array type")
|
||||
|
||||
self.add_key_value(key, val, GGUFValueType.ARRAY)
|
||||
|
||||
@staticmethod
|
||||
|
@ -222,7 +274,7 @@ class GGUFWriter:
|
|||
if self.state is not WriterState.NO_FILE:
|
||||
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
|
||||
|
||||
if name in self.tensors:
|
||||
if any(name in tensors for tensors in self.tensors):
|
||||
raise ValueError(f'Duplicated tensor name {name!r}')
|
||||
|
||||
if raw_dtype is None:
|
||||
|
@ -247,7 +299,18 @@ class GGUFWriter:
|
|||
if tensor_dtype == np.uint8:
|
||||
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
||||
|
||||
self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
|
||||
# make sure there is at least one tensor before splitting
|
||||
if len(self.tensors[-1]) > 0:
|
||||
if ( # split when over tensor limit
|
||||
self.split_max_tensors != 0
|
||||
and len(self.tensors[-1]) >= self.split_max_tensors
|
||||
) or ( # split when over size limit
|
||||
self.split_max_size != 0
|
||||
and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
|
||||
):
|
||||
self.tensors.append({})
|
||||
|
||||
self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
|
||||
|
||||
def add_tensor(
|
||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||
|
@ -264,7 +327,7 @@ class GGUFWriter:
|
|||
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
|
||||
|
||||
if self.temp_file is None:
|
||||
self.tensors[name].tensor = tensor
|
||||
self.tensors[-1][name].tensor = tensor
|
||||
return
|
||||
|
||||
tensor.tofile(self.temp_file)
|
||||
|
@ -282,9 +345,24 @@ class GGUFWriter:
|
|||
|
||||
if self.endianess == GGUFEndian.BIG:
|
||||
tensor.byteswap(inplace=True)
|
||||
self.write_padding(self.fout, self.fout.tell())
|
||||
tensor.tofile(self.fout)
|
||||
self.write_padding(self.fout, tensor.nbytes)
|
||||
|
||||
file_id = -1
|
||||
for i, tensors in enumerate(self.tensors):
|
||||
if len(tensors) > 0:
|
||||
file_id = i
|
||||
break
|
||||
|
||||
fout = self.fout[file_id]
|
||||
|
||||
# pop the first tensor info
|
||||
# TODO: cleaner way to get the first key
|
||||
first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
|
||||
ti = self.tensors[file_id].pop(first_tensor_name)
|
||||
assert ti.nbytes == tensor.nbytes
|
||||
|
||||
self.write_padding(fout, fout.tell())
|
||||
tensor.tofile(fout)
|
||||
self.write_padding(fout, tensor.nbytes)
|
||||
|
||||
self.state = WriterState.WEIGHTS
|
||||
|
||||
|
@ -293,31 +371,43 @@ class GGUFWriter:
|
|||
|
||||
assert self.fout is not None
|
||||
|
||||
self.write_padding(self.fout, self.fout.tell())
|
||||
for fout in self.fout:
|
||||
self.write_padding(fout, fout.tell())
|
||||
|
||||
if self.temp_file is None:
|
||||
shard_bar = None
|
||||
bar = None
|
||||
|
||||
if progress:
|
||||
from tqdm import tqdm
|
||||
|
||||
total_bytes = sum(t.nbytes for t in self.tensors.values())
|
||||
total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
|
||||
|
||||
if len(self.fout) > 1:
|
||||
shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
|
||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||
|
||||
# relying on the fact that Python dicts preserve insertion order (since 3.7)
|
||||
for ti in self.tensors.values():
|
||||
assert ti.tensor is not None # can only iterate once over the tensors
|
||||
assert ti.tensor.nbytes == ti.nbytes
|
||||
ti.tensor.tofile(self.fout)
|
||||
if bar is not None:
|
||||
bar.update(ti.nbytes)
|
||||
self.write_padding(self.fout, ti.nbytes)
|
||||
ti.tensor = None
|
||||
for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
|
||||
if shard_bar is not None:
|
||||
shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
|
||||
total = sum(ti.nbytes for ti in tensors.values())
|
||||
shard_bar.reset(total=(total if total > 0 else None))
|
||||
|
||||
# relying on the fact that Python dicts preserve insertion order (since 3.7)
|
||||
for ti in tensors.values():
|
||||
assert ti.tensor is not None # can only iterate once over the tensors
|
||||
assert ti.tensor.nbytes == ti.nbytes
|
||||
ti.tensor.tofile(fout)
|
||||
if shard_bar is not None:
|
||||
shard_bar.update(ti.nbytes)
|
||||
if bar is not None:
|
||||
bar.update(ti.nbytes)
|
||||
self.write_padding(fout, ti.nbytes)
|
||||
ti.tensor = None
|
||||
else:
|
||||
self.temp_file.seek(0)
|
||||
|
||||
shutil.copyfileobj(self.temp_file, self.fout)
|
||||
shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
|
||||
self.flush()
|
||||
self.temp_file.close()
|
||||
|
||||
|
@ -325,11 +415,13 @@ class GGUFWriter:
|
|||
|
||||
def flush(self) -> None:
|
||||
assert self.fout is not None
|
||||
self.fout.flush()
|
||||
for fout in self.fout:
|
||||
fout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
if self.fout is not None:
|
||||
self.fout.close()
|
||||
for fout in self.fout:
|
||||
fout.close()
|
||||
self.fout = None
|
||||
|
||||
def add_architecture(self) -> None:
|
||||
|
@ -400,6 +492,9 @@ class GGUFWriter:
|
|||
def add_parallel_residual(self, use: bool) -> None:
|
||||
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
|
||||
|
||||
def add_decoder_start_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
|
||||
|
||||
def add_head_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
|
||||
|
||||
|
@ -448,6 +543,9 @@ class GGUFWriter:
|
|||
def add_kv_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_relative_attn_buckets_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
|
||||
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
|
@ -538,6 +636,12 @@ class GGUFWriter:
|
|||
def add_add_space_prefix(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
|
||||
|
||||
def add_remove_extra_whitespaces(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
|
||||
|
||||
def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
|
||||
self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
|
||||
|
||||
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
|
||||
if not isinstance(value, str):
|
||||
template_default = None
|
||||
|
@ -599,9 +703,12 @@ class GGUFWriter:
|
|||
kv_data += self._pack("Q", len(encoded_val))
|
||||
kv_data += encoded_val
|
||||
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
|
||||
ltype = GGUFValueType.get_type(val[0])
|
||||
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
||||
raise ValueError("All items in a GGUF array should be of the same type")
|
||||
if isinstance(val, bytes):
|
||||
ltype = GGUFValueType.UINT8
|
||||
else:
|
||||
ltype = GGUFValueType.get_type(val[0])
|
||||
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
||||
raise ValueError("All items in a GGUF array should be of the same type")
|
||||
kv_data += self._pack("I", ltype)
|
||||
kv_data += self._pack("Q", len(val))
|
||||
for item in val:
|
||||
|
@ -611,6 +718,13 @@ class GGUFWriter:
|
|||
|
||||
return kv_data
|
||||
|
||||
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
|
||||
assert self.fout is not None
|
||||
self.fout.write(self._pack(fmt, value, skip_pack_prefix))
|
||||
@staticmethod
|
||||
def format_n_bytes_to_str(num: int) -> str:
|
||||
if num == 0:
|
||||
return "negligible - metadata only"
|
||||
fnum = float(num)
|
||||
for unit in ("", "K", "M", "G"):
|
||||
if abs(fnum) < 1000.0:
|
||||
return f"{fnum:3.1f}{unit}"
|
||||
fnum /= 1000.0
|
||||
return f"{fnum:.1f}T - over 1TB, split recommended"
|
||||
|
|
|
@ -24,6 +24,7 @@ class TensorNameMap:
|
|||
"backbone.embedding", # mamba
|
||||
"backbone.embeddings", # mamba-hf
|
||||
"transformer.in_out_embed", # Grok
|
||||
"shared", # t5
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
|
@ -421,6 +422,120 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.FFN_SUB_NORM: (
|
||||
"model.layers.{bid}.mlp.ffn_layernorm", # bitnet
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_ATTN_NORM: (
|
||||
"decoder.block.{bid}.layer.0.layer_norm", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_ATTN_Q: (
|
||||
"decoder.block.{bid}.layer.0.SelfAttention.q", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_ATTN_K: (
|
||||
"decoder.block.{bid}.layer.0.SelfAttention.k", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_ATTN_V: (
|
||||
"decoder.block.{bid}.layer.0.SelfAttention.v", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_ATTN_OUT: (
|
||||
"decoder.block.{bid}.layer.0.SelfAttention.o", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_ATTN_REL_B: (
|
||||
"decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_NORM: (
|
||||
"decoder.block.{bid}.layer.1.layer_norm", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_Q: (
|
||||
"decoder.block.{bid}.layer.1.EncDecAttention.q", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_K: (
|
||||
"decoder.block.{bid}.layer.1.EncDecAttention.k", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_V: (
|
||||
"decoder.block.{bid}.layer.1.EncDecAttention.v", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_OUT: (
|
||||
"decoder.block.{bid}.layer.1.EncDecAttention.o", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: (
|
||||
"decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_FFN_NORM: (
|
||||
"decoder.block.{bid}.layer.2.layer_norm", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_FFN_GATE: (
|
||||
"decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_FFN_UP: (
|
||||
"decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5
|
||||
"decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_FFN_DOWN: (
|
||||
"decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.DEC_OUTPUT_NORM: (
|
||||
"decoder.final_layer_norm", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_ATTN_NORM: (
|
||||
"encoder.block.{bid}.layer.0.layer_norm", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_ATTN_Q: (
|
||||
"encoder.block.{bid}.layer.0.SelfAttention.q", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_ATTN_K: (
|
||||
"encoder.block.{bid}.layer.0.SelfAttention.k", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_ATTN_V: (
|
||||
"encoder.block.{bid}.layer.0.SelfAttention.v", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_ATTN_OUT: (
|
||||
"encoder.block.{bid}.layer.0.SelfAttention.o", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_ATTN_REL_B: (
|
||||
"encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_FFN_NORM: (
|
||||
"encoder.block.{bid}.layer.1.layer_norm", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_FFN_GATE: (
|
||||
"encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_FFN_UP: (
|
||||
"encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5
|
||||
"encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_FFN_DOWN: (
|
||||
"encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
||||
"encoder.final_layer_norm", # t5
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
|
|
|
@ -208,7 +208,9 @@ def translate_tensor_name(name):
|
|||
'ssm_d': 'State space model skip connection',
|
||||
'ssm_dt': 'State space model time step',
|
||||
'ssm_out': 'State space model output projection',
|
||||
'blk': 'Block'
|
||||
'blk': 'Block',
|
||||
'enc': 'Encoder',
|
||||
'dec': 'Decoder',
|
||||
}
|
||||
|
||||
expanded_words = []
|
||||
|
@ -291,6 +293,10 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
|
|||
tensor_group_name = "base"
|
||||
if tensor_components[0] == 'blk':
|
||||
tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}"
|
||||
elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk':
|
||||
tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}"
|
||||
elif tensor_components[0] in ['enc', 'dec']:
|
||||
tensor_group_name = f"{tensor_components[0]}"
|
||||
|
||||
# Check if new Tensor Group
|
||||
if tensor_group_name not in tensor_groups:
|
||||
|
|
|
@ -12785,12 +12785,6 @@ static int llama_decode_internal(
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// print timing information per ggml operation (for debugging purposes)
|
||||
// requires GGML_PERF to be defined
|
||||
ggml_graph_print(gf);
|
||||
#endif
|
||||
|
||||
// plot the computation graph in dot format (for debugging purposes)
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
|
|
43
sgemm.cpp
43
sgemm.cpp
|
@ -249,9 +249,8 @@ class tinyBLAS {
|
|||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n, int task) {
|
||||
if (task == GGML_TASK_TYPE_COMPUTE)
|
||||
mnpack(0, m, 0, n);
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -458,9 +457,8 @@ class tinyBLAS_Q0_ARM {
|
|||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n, int task) {
|
||||
if (task == GGML_TASK_TYPE_COMPUTE)
|
||||
mnpack(0, m, 0, n);
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -596,9 +594,8 @@ class tinyBLAS_Q0_AVX {
|
|||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n, int task) {
|
||||
if (task == GGML_TASK_TYPE_COMPUTE)
|
||||
mnpack(0, m, 0, n);
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -829,7 +826,7 @@ class tinyBLAS_Q0_AVX {
|
|||
* For example, for single-threaded single-precision GEMM you can say
|
||||
*
|
||||
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
|
||||
* 0, 1, GGML_TASK_TYPE_COMPUTE,
|
||||
* 0, 1,
|
||||
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
|
||||
*
|
||||
* @param m is rows in `A` and `C`
|
||||
|
@ -843,14 +840,13 @@ class tinyBLAS_Q0_AVX {
|
|||
* @param ldc is row stride of `C`
|
||||
* @param ith is thread id (must be less than `nth`)
|
||||
* @param nth is number of threads (must be greater than zero)
|
||||
* @param task is GGML task type
|
||||
* @param Atype is GGML data type of `A`
|
||||
* @param Btype is GGML data type of `B`
|
||||
* @param Ctype is GGML data type of `C`
|
||||
* @return true if this function was able to service the matmul request
|
||||
*/
|
||||
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
|
||||
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
|
||||
int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
|
||||
|
||||
assert(m >= 0);
|
||||
assert(n >= 0);
|
||||
|
@ -877,7 +873,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const float *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__AVX__) || defined(__AVX2__)
|
||||
if (k % 8)
|
||||
|
@ -887,7 +883,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const float *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__ARM_NEON)
|
||||
if (n < 4)
|
||||
|
@ -899,7 +895,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const float *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
|
@ -917,7 +913,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const float *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
|
||||
if (k % 8)
|
||||
|
@ -929,7 +925,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const float *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
|
||||
if (n < 8)
|
||||
|
@ -943,7 +939,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const ggml_fp16_t *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
|
||||
if (k % 4)
|
||||
|
@ -955,7 +951,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const float *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
|
@ -971,7 +967,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const block_q8_0 *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__ARM_FEATURE_DOTPROD)
|
||||
tinyBLAS_Q0_ARM<block_q8_0> tb{
|
||||
|
@ -979,7 +975,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const block_q8_0 *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
|
@ -995,7 +991,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const block_q8_0 *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__ARM_FEATURE_DOTPROD)
|
||||
tinyBLAS_Q0_ARM<block_q4_0> tb{
|
||||
|
@ -1003,7 +999,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(const block_q8_0 *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n, task);
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
|
@ -1025,7 +1021,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
(void)ldc;
|
||||
(void)ith;
|
||||
(void)nth;
|
||||
(void)task;
|
||||
(void)Atype;
|
||||
(void)Btype;
|
||||
(void)Ctype;
|
||||
|
|
2
sgemm.h
2
sgemm.h
|
@ -7,7 +7,7 @@ extern "C" {
|
|||
|
||||
bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
|
||||
const void *, int64_t, void *, int64_t, int, int,
|
||||
int, int, int, int);
|
||||
int, int, int);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue