Merge branch 'master' into python-flake8

This commit is contained in:
Galunid 2023-11-18 21:18:17 +01:00
commit 1a738a5d1d
29 changed files with 508 additions and 457 deletions

View file

@ -574,8 +574,12 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GE
endif() endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected") message(STATUS "PowerPC detected")
add_compile_options(-mcpu=native -mtune=native) if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) add_compile_options(-mcpu=powerpc64le)
else()
add_compile_options(-mcpu=native -mtune=native)
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
endif()
else() else()
message(STATUS "Unknown architecture") message(STATUS "Unknown architecture")
endif() endif()

View file

@ -342,6 +342,12 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
endif endif
endif endif
ifneq ($(filter ppc64le%,$(UNAME_M)),)
MK_CFLAGS += -mcpu=powerpc64le
MK_CXXFLAGS += -mcpu=powerpc64le
CUDA_POWER_ARCH = 1
endif
else else
MK_CFLAGS += -march=rv64gcv -mabi=lp64d MK_CFLAGS += -march=rv64gcv -mabi=lp64d
MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
@ -392,6 +398,8 @@ else
endif #LLAMA_CUDA_NVCC endif #LLAMA_CUDA_NVCC
ifdef CUDA_DOCKER_ARCH ifdef CUDA_DOCKER_ARCH
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH) NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
else ifdef CUDA_POWER_ARCH
NVCCFLAGS +=
else else
NVCCFLAGS += -arch=native NVCCFLAGS += -arch=native
endif # CUDA_DOCKER_ARCH endif # CUDA_DOCKER_ARCH

View file

@ -1072,6 +1072,12 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
return result; return result;
} }
bool llama_should_add_bos_token(const llama_model * model) {
const int add_bos = llama_add_bos_token(model);
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
}
// //
// YAML utils // YAML utils
// //
@ -1188,6 +1194,7 @@ void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const cha
if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) { if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) {
data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)");
data_str = "\"" + data_str + "\""; data_str = "\"" + data_str + "\"";
fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
return; return;

View file

@ -200,6 +200,10 @@ std::string llama_detokenize_bpe(
llama_context * ctx, llama_context * ctx,
const std::vector<llama_token> & tokens); const std::vector<llama_token> & tokens);
// Uses the value from the model metadata if possible, otherwise
// defaults to true when model type is SPM, otherwise false.
bool llama_should_add_bos_token(const llama_model * model);
// //
// YAML utils // YAML utils
// //

View file

@ -1136,6 +1136,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2); fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip); fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f); fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " -ngl N, --n-gpu-layers N Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -1355,6 +1356,17 @@ bool consume_common_train_arg(
return true; return true;
} }
params->adam_gclip = std::stof(argv[i]); params->adam_gclip = std::stof(argv[i]);
} else if (arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params->n_gpu_layers = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
} else if (arg == "-h" || arg == "--help") { } else if (arg == "-h" || arg == "--help") {
params->print_usage = true; params->print_usage = true;
return true; return true;

View file

@ -1,317 +0,0 @@
#!/usr/bin/env python3
# HF baichuan --> gguf conversion
from __future__ import annotations
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
import itertools
import numpy as np
import torch
from sentencepiece import SentencePieceProcessor # type: ignore[import]
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
if TYPE_CHECKING:
from typing import TypeAlias
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
# reverse HF permute back to original pth layout
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
def reverse_hf_permute_part(weights: NDArray, n_part: int, n_head: int, n_head_kv: int| None = None) -> NDArray:
r = weights.shape[0] // 3
return (reverse_hf_permute(weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
def reverse_hf_part(weights: NDArray, n_part: int) -> NDArray:
r = weights.shape[0] // 3
return weights[r * n_part : r * n_part + r, ...]
def count_model_parts(dir_model: str) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1
if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a HuggingFace LLaMA model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
)
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file, or model file itself (*.bin)",
)
parser.add_argument(
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
help="output format - use 0 for float32, 1 for float16",
)
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
return parser.parse_args()
args = parse_args()
dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)
endianess = gguf.GGUFEndian.LITTLE
if args.bigendian:
endianess = gguf.GGUFEndian.BIG
endianess_str = "Big Endian" if args.bigendian else "Little Endian"
print(f"gguf: Conversion Endianess {endianess}")
# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16
# map from ftype to string
ftype_str = ["f32", "f16"]
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
print("gguf: loading model "+dir_model.name)
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
print("hello print: ",hparams["architectures"][0])
if hparams["architectures"][0] != "BaichuanForCausalLM" and hparams["architectures"][0] != "BaiChuanForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])
sys.exit()
# get number of model parts
num_parts = count_model_parts(dir_model)
print(f"num_parts:{num_parts}\n")
ARCH=gguf.MODEL_ARCH.BAICHUAN
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
print("gguf: get model metadata")
block_count = hparams["num_hidden_layers"]
head_count = hparams["num_attention_heads"]
if "num_key_value_heads" in hparams:
head_count_kv = hparams["num_key_value_heads"]
else:
head_count_kv = head_count
if "_name_or_path" in hparams:
hf_repo = hparams["_name_or_path"]
else:
hf_repo = ""
if "max_sequence_length" in hparams:
ctx_length = hparams["max_sequence_length"]
elif "max_position_embeddings" in hparams:
ctx_length = hparams["max_position_embeddings"]
elif "model_max_length" in hparams:
ctx_length = hparams["model_max_length"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()
gguf_writer.add_name(dir_model.name)
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_tensor_data_layout("Meta AI original pth")
gguf_writer.add_context_length(ctx_length)
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
gguf_writer.add_head_count(head_count)
gguf_writer.add_head_count_kv(head_count_kv)
gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
if "type" in hparams["rope_scaling"]:
if hparams["rope_scaling"]["type"] == "linear":
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
# TOKENIZATION
print("gguf: get tokenizer metadata")
tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []
tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
sys.exit(1)
# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")
tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
vocab_size = hparams.get('vocab_size')
if vocab_size is None:
vocab_size = tokenizer.vocab_size()
for i in range(vocab_size):
text: bytes
score: float
piece = tokenizer.id_to_piece(i)
text = piece.encode("utf-8")
score = tokenizer.get_score(i)
toktype = 1 # defualt to normal token type
if tokenizer.is_unknown(i):
toktype = 2
if tokenizer.is_control(i):
toktype = 3
# toktype = 4 is user-defined = tokens from added_tokens.json
if tokenizer.is_unused(i):
toktype = 5
if tokenizer.is_byte(i):
toktype = 6
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
added_tokens_file = dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
addtokens_json = json.load(f)
print("gguf: get added tokens")
for key in addtokens_json:
tokens.append( key.encode("utf-8") )
scores.append(-1000.0)
toktypes.append(4) # user-defined token type
gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer)
# TENSORS
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
# tensor info
print("gguf: get tensor metadata")
if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)
for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
tmp=model_part
for i in range(block_count):
if f"model.layers.{i}.self_attn.W_pack.weight" in model_part:
print(f"Unpacking and permuting layer {i}")
tmp[f"model.layers.{i}.self_attn.q_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],0,head_count,head_count)
tmp[f"model.layers.{i}.self_attn.k_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],1,head_count,head_count_kv)
tmp[f"model.layers.{i}.self_attn.v_proj.weight"]=reverse_hf_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],2)
del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
for name in model_part.keys():
data = model_part[name]
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue
old_dtype = data.dtype
# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)
data = data.squeeze().numpy()
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(name + " -> " + new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
gguf_writer.add_tensor(new_name, data)
print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()
gguf_writer.close()
print(f"gguf: model successfully exported to '{fname_out}'")
print("")

View file

@ -193,7 +193,7 @@ class Model:
return gguf.MODEL_ARCH.MPT return gguf.MODEL_ARCH.MPT
if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
return gguf.MODEL_ARCH.BAICHUAN return gguf.MODEL_ARCH.BAICHUAN
if arch == "FalconForCausalLM": if arch in ("FalconForCausalLM", "RWForCausalLM"):
return gguf.MODEL_ARCH.FALCON return gguf.MODEL_ARCH.FALCON
if arch == "GPTBigCodeForCausalLM": if arch == "GPTBigCodeForCausalLM":
return gguf.MODEL_ARCH.STARCODER return gguf.MODEL_ARCH.STARCODER

View file

@ -704,6 +704,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
data_base_path=pickle_paths[0][:-4], data_base_path=pickle_paths[0][:-4],
zip_file=zf) zip_file=zf)
model = unpickler.load() model = unpickler.load()
if 'model' in model: model = model['model']
as_dict = dict(model.items()) as_dict = dict(model.items())
return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)

View file

@ -24,6 +24,7 @@ else()
add_subdirectory(llama-bench) add_subdirectory(llama-bench)
add_subdirectory(llava) add_subdirectory(llava)
add_subdirectory(main) add_subdirectory(main)
add_subdirectory(tokenize)
add_subdirectory(parallel) add_subdirectory(parallel)
add_subdirectory(perplexity) add_subdirectory(perplexity)
add_subdirectory(quantize) add_subdirectory(quantize)

View file

@ -3,9 +3,7 @@
import argparse import argparse
import gguf import gguf
import os
import struct import struct
import sys
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path

View file

@ -548,35 +548,35 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max); struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
randomize_tensor_normal(lora->tok_embeddings_a, rnd); randomize_tensor_normal(lora->tok_embeddings_a, rnd);
randomize_tensor_normal(lora->tok_embeddings_b, rnd); ggml_set_zero(lora->tok_embeddings_b);
randomize_tensor_normal(lora->norm_a, rnd); randomize_tensor_normal(lora->norm_a, rnd);
randomize_tensor_normal(lora->norm_b, rnd); ggml_set_zero(lora->norm_b);
randomize_tensor_normal(lora->output_a, rnd); randomize_tensor_normal(lora->output_a, rnd);
randomize_tensor_normal(lora->output_b, rnd); ggml_set_zero(lora->output_b);
for (uint32_t i = 0; i < n_layer; ++i) { for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = lora->layers[i]; auto & layer = lora->layers[i];
randomize_tensor_normal(layer.attention_norm_a, rnd); randomize_tensor_normal(layer.attention_norm_a, rnd);
randomize_tensor_normal(layer.attention_norm_b, rnd); ggml_set_zero(layer.attention_norm_b);
randomize_tensor_normal(layer.wq_a, rnd); randomize_tensor_normal(layer.wq_a, rnd);
randomize_tensor_normal(layer.wq_b, rnd); ggml_set_zero(layer.wq_b);
randomize_tensor_normal(layer.wk_a, rnd); randomize_tensor_normal(layer.wk_a, rnd);
randomize_tensor_normal(layer.wk_b, rnd); ggml_set_zero(layer.wk_b);
randomize_tensor_normal(layer.wv_a, rnd); randomize_tensor_normal(layer.wv_a, rnd);
randomize_tensor_normal(layer.wv_b, rnd); ggml_set_zero(layer.wv_b);
randomize_tensor_normal(layer.wo_a, rnd); randomize_tensor_normal(layer.wo_a, rnd);
randomize_tensor_normal(layer.wo_b, rnd); ggml_set_zero(layer.wo_b);
randomize_tensor_normal(layer.ffn_norm_a, rnd); randomize_tensor_normal(layer.ffn_norm_a, rnd);
randomize_tensor_normal(layer.ffn_norm_b, rnd); ggml_set_zero(layer.ffn_norm_b);
randomize_tensor_normal(layer.w1_a, rnd); randomize_tensor_normal(layer.w1_a, rnd);
randomize_tensor_normal(layer.w1_b, rnd); ggml_set_zero(layer.w1_b);
randomize_tensor_normal(layer.w2_a, rnd); randomize_tensor_normal(layer.w2_a, rnd);
randomize_tensor_normal(layer.w2_b, rnd); ggml_set_zero(layer.w2_b);
randomize_tensor_normal(layer.w3_a, rnd); randomize_tensor_normal(layer.w3_a, rnd);
randomize_tensor_normal(layer.w3_b, rnd); ggml_set_zero(layer.w3_b);
} }
free_random_normal_distribution(rnd); free_random_normal_distribution(rnd);
@ -1460,17 +1460,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
} }
params->n_rank_w3 = std::stoi(argv[i]); params->n_rank_w3 = std::stoi(argv[i]);
params->custom_n_rank_w3 = true; params->custom_n_rank_w3 = true;
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params->common.n_gpu_layers = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
train_print_usage(argc, argv, &default_params); train_print_usage(argc, argv, &default_params);

View file

@ -230,7 +230,7 @@ int main(int argc, char ** argv) {
LOG_TEE("\n"); LOG_TEE("\n");
LOG_TEE("%s\n", get_system_info(params).c_str()); LOG_TEE("%s\n", get_system_info(params).c_str());
} }
const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = llama_should_add_bos_token(model);
LOG("add_bos: %d\n", add_bos); LOG("add_bos: %d\n", add_bos);
bool suff_rm_leading_spc = params.escape; bool suff_rm_leading_spc = params.escape;

View file

@ -208,9 +208,10 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
int n_past = 0; int n_past = 0;
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
// llava chat format is "<system_prompt>\nUSER:<image_embeddings>\n<textual_prompt>\nASSISTANT:" // llava chat format is "<system_prompt>\nUSER:<image_embeddings>\n<textual_prompt>\nASSISTANT:"
eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, true); eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, add_bos);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past); llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, (prompt + "\nASSISTANT:").c_str(), params->n_batch, &n_past, false); eval_string(ctx_llava->ctx_llama, (prompt + "\nASSISTANT:").c_str(), params->n_batch, &n_past, false);

View file

@ -127,7 +127,14 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long
fclose(file); fclose(file);
return false; return false;
} }
fread(buffer, 1, fileSize, file); // Read the file into the buffer errno = 0;
size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
if (ferror(file)) {
die_fmt("read error: %s", strerror(errno));
}
if (ret != (size_t) fileSize) {
die("unexpectedly reached end of file");
}
fclose(file); // Close the file fclose(file); // Close the file
*bytesOut = buffer; *bytesOut = buffer;

View file

@ -229,7 +229,7 @@ int main(int argc, char ** argv) {
} }
} }
const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = llama_should_add_bos_token(model);
LOG("add_bos: %d\n", add_bos); LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;

View file

@ -149,8 +149,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const bool add_bos = is_spm;
fprintf(stderr, "%s: tokenizing the input ..\n", __func__); fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@ -288,8 +287,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const bool add_bos = is_spm;
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now(); auto tim1 = std::chrono::high_resolution_clock::now();
@ -481,7 +479,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
fprintf(stderr, "================================= is_spm = %d\n", is_spm); fprintf(stderr, "================================= is_spm = %d\n", is_spm);
// This is needed as usual for LLaMA models // This is needed as usual for LLaMA models
const bool add_bos = is_spm; const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
// Number of tasks to use when computing the score // Number of tasks to use when computing the score
if ( params.hellaswag_tasks < hs_task_count ) { if ( params.hellaswag_tasks < hs_task_count ) {

View file

@ -501,6 +501,7 @@ struct llama_server_context
bool multimodal = false; bool multimodal = false;
bool clean_kv_cache = true; bool clean_kv_cache = true;
bool all_slots_are_idle = false; bool all_slots_are_idle = false;
bool add_bos_token = true;
int32_t id_gen; int32_t id_gen;
int32_t n_ctx; // total context for all clients / slots int32_t n_ctx; // total context for all clients / slots
@ -573,6 +574,8 @@ struct llama_server_context
n_ctx = llama_n_ctx(ctx); n_ctx = llama_n_ctx(ctx);
add_bos_token = llama_should_add_bos_token(model);
return true; return true;
} }
@ -864,7 +867,7 @@ struct llama_server_context
} }
void update_system_prompt() { void update_system_prompt() {
system_tokens = ::llama_tokenize(ctx, system_prompt, true); system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
llama_batch_clear(batch); llama_batch_clear(batch);
@ -1552,7 +1555,7 @@ struct llama_server_context
} }
else else
{ {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
} }
slot.num_prompt_tokens = prompt_tokens.size(); slot.num_prompt_tokens = prompt_tokens.size();
@ -1629,7 +1632,7 @@ struct llama_server_context
const bool has_images = process_images(slot); const bool has_images = process_images(slot);
// process the prefix of first image // process the prefix of first image
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens; std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
{ {
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false); llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);

View file

@ -0,0 +1,5 @@
set(TARGET tokenize)
add_executable(${TARGET} tokenize.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -0,0 +1,44 @@
#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
int main(int argc, char ** argv) {
if (argc < 3 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH PROMPT [--ids]\n" , argv[0]);
return 1;
}
const char * model_path = argv[1];
const char * prompt = argv[2];
const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids";
llama_backend_init(false);
llama_model_params model_params = llama_model_default_params();
model_params.vocab_only = true;
llama_model * model = llama_load_model_from_file(model_path, model_params);
llama_context_params ctx_params = llama_context_default_params();
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
const bool add_bos = true;
std::vector<llama_token> tokens;
tokens = ::llama_tokenize(model, prompt, add_bos, true);
for (int i = 0; i < (int) tokens.size(); i++) {
if (printing_ids) {
printf("%d\n", tokens[i]);
} else {
printf("%6d -> '%s'\n", tokens[i], llama_token_to_piece(ctx, tokens[i]).c_str());
}
}
return 0;
}

View file

@ -235,7 +235,7 @@ typedef float2 dfloat2;
#endif //GGML_CUDA_F16 #endif //GGML_CUDA_F16
static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
int x32 = 0; int x32 = 0;
x32 |= x16[0] << 0; x32 |= x16[0] << 0;
@ -245,7 +245,7 @@ static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const
} }
static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
int x32 = 0; int x32 = 0;
x32 |= x16[0] << 0; x32 |= x16[0] << 0;
@ -255,11 +255,11 @@ static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, con
} }
static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
} }
static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
} }
template<typename T> template<typename T>
@ -469,7 +469,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
#define MUL_MAT_SRC1_COL_STRIDE 128 #define MUL_MAT_SRC1_COL_STRIDE 128
#define MAX_STREAMS 8 #define MAX_STREAMS 8
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } };
struct ggml_tensor_extra_gpu { struct ggml_tensor_extra_gpu {
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@ -2248,6 +2248,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh; (void)x_sc;
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
@ -2259,7 +2260,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh; (void)x_sc;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
GGML_CUDA_ASSUME(k >= 0); GGML_CUDA_ASSUME(k >= 0);
@ -2268,7 +2269,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI4_0; const int kbx = k / QI4_0;
const int kqsx = k % QI4_0; const int kqsx = k % QI4_0;
const block_q4_0 * bx0 = (block_q4_0 *) vx; const block_q4_0 * bx0 = (const block_q4_0 *) vx;
float * x_dmf = (float *) x_dm; float * x_dmf = (float *) x_dm;
@ -2306,9 +2307,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh; (void)x_sc;
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const float * x_dmf = (float *) x_dm; const float * x_dmf = (const float *) x_dm;
int u[2*VDR_Q4_0_Q8_1_MMQ]; int u[2*VDR_Q4_0_Q8_1_MMQ];
@ -2342,6 +2344,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh; (void)x_sc;
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
@ -2353,6 +2356,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh; (void)x_sc;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -2362,7 +2366,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI4_1; const int kbx = k / QI4_1;
const int kqsx = k % QI4_1; const int kqsx = k % QI4_1;
const block_q4_1 * bx0 = (block_q4_1 *) vx; const block_q4_1 * bx0 = (const block_q4_1 *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -2397,6 +2401,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh; (void)x_sc;
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
@ -2434,6 +2439,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh; (void)x_sc;
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
@ -2445,6 +2451,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh; (void)x_sc;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -2454,7 +2461,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI5_0; const int kbx = k / QI5_0;
const int kqsx = k % QI5_0; const int kqsx = k % QI5_0;
const block_q5_0 * bx0 = (block_q5_0 *) vx; const block_q5_0 * bx0 = (const block_q5_0 *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -2509,6 +2516,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh; (void)x_sc;
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
@ -2548,6 +2556,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh; (void)x_sc;
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
@ -2559,6 +2568,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh; (void)x_sc;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -2568,7 +2578,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI5_1; const int kbx = k / QI5_1;
const int kqsx = k % QI5_1; const int kqsx = k % QI5_1;
const block_q5_1 * bx0 = (block_q5_1 *) vx; const block_q5_1 * bx0 = (const block_q5_1 *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -2620,6 +2630,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh; (void)x_sc;
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
@ -2654,6 +2665,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh; (void)x_sc;
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
@ -2665,6 +2677,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh; (void)x_sc;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -2675,7 +2688,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kqsx = k % QI8_0; const int kqsx = k % QI8_0;
float * x_dmf = (float *) x_dm; float * x_dmf = (float *) x_dm;
const block_q8_0 * bx0 = (block_q8_0 *) vx; const block_q8_0 * bx0 = (const block_q8_0 *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -2710,6 +2723,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh; (void)x_sc;
const float * x_dmf = (const float *) x_dm; const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds; const float * y_df = (const float *) y_ds;
@ -2743,6 +2757,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh;
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
@ -2756,6 +2771,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -2765,7 +2781,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI2_K; const int kbx = k / QI2_K;
const int kqsx = k % QI2_K; const int kqsx = k % QI2_K;
const block_q2_K * bx0 = (block_q2_K *) vx; const block_q2_K * bx0 = (const block_q2_K *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -2813,6 +2829,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh;
const int kbx = k / QI2_K; const int kbx = k / QI2_K;
const int ky = (k % QI2_K) * QR2_K; const int ky = (k % QI2_K) * QR2_K;
@ -2886,7 +2903,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI3_K; const int kbx = k / QI3_K;
const int kqsx = k % QI3_K; const int kqsx = k % QI3_K;
const block_q3_K * bx0 = (block_q3_K *) vx; const block_q3_K * bx0 = (const block_q3_K *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -2967,7 +2984,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
const float * x_dmf = (const float *) x_dm; const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds; const float * y_df = (const float *) y_ds;
const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
@ -3082,6 +3099,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh;
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
@ -3095,6 +3113,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -3104,7 +3123,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI4_K; // == 0 if QK_K == 256 const int kbx = k / QI4_K; // == 0 if QK_K == 256
const int kqsx = k % QI4_K; // == k if QK_K == 256 const int kqsx = k % QI4_K; // == k if QK_K == 256
const block_q4_K * bx0 = (block_q4_K *) vx; const block_q4_K * bx0 = (const block_q4_K *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -3149,7 +3168,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
const int * scales = (int *) bxi->scales; const int * scales = (const int *) bxi->scales;
const int ksc = k % (WARP_SIZE/8); const int ksc = k % (WARP_SIZE/8);
@ -3164,6 +3183,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh;
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
@ -3263,6 +3283,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh;
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
@ -3276,6 +3297,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -3285,7 +3307,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI5_K; // == 0 if QK_K == 256 const int kbx = k / QI5_K; // == 0 if QK_K == 256
const int kqsx = k % QI5_K; // == k if QK_K == 256 const int kqsx = k % QI5_K; // == k if QK_K == 256
const block_q5_K * bx0 = (block_q5_K *) vx; const block_q5_K * bx0 = (const block_q5_K *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -3341,7 +3363,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
const int * scales = (int *) bxi->scales; const int * scales = (const int *) bxi->scales;
const int ksc = k % (WARP_SIZE/8); const int ksc = k % (WARP_SIZE/8);
@ -3356,6 +3378,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh;
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
@ -3392,6 +3415,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh;
__shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
@ -3405,6 +3429,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K( template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
(void)x_qh;
GGML_CUDA_ASSUME(i_offset >= 0); GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps); GGML_CUDA_ASSUME(i_offset < nwarps);
@ -3414,7 +3439,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = k / QI6_K; // == 0 if QK_K == 256 const int kbx = k / QI6_K; // == 0 if QK_K == 256
const int kqsx = k % QI6_K; // == k if QK_K == 256 const int kqsx = k % QI6_K; // == k if QK_K == 256
const block_q6_K * bx0 = (block_q6_K *) vx; const block_q6_K * bx0 = (const block_q6_K *) vx;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@ -3476,6 +3501,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
(void)x_qh;
const float * x_dmf = (const float *) x_dm; const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds; const float * y_df = (const float *) y_ds;
@ -3518,7 +3544,7 @@ static __device__ __forceinline__ void mul_mat_q(
__shared__ int tile_y_qs[mmq_x * WARP_SIZE]; __shared__ int tile_y_qs[mmq_x * WARP_SIZE];
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f}; float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
@ -5840,7 +5866,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
return ptr; return ptr;
} }
#ifdef DEBUG_CUDA_MALLOC #ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz, fprintf(stderr, "%s: %d buffers, max_size = %u MiB, tot_size = %u MiB, requested %u MiB\n", __func__, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024)); (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
#endif #endif
void * ptr; void * ptr;
@ -5978,7 +6004,7 @@ void * ggml_cuda_host_malloc(size_t size) {
// The allocation error can be bypassed. A null ptr will assigned out of this function. // The allocation error can be bypassed. A null ptr will assigned out of this function.
// This can fixed the OOM error in WSL. // This can fixed the OOM error in WSL.
cudaGetLastError(); cudaGetLastError();
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", fprintf(stderr, "WARNING: failed to allocate %.2f MiB of pinned memory: %s\n",
size/1024.0/1024.0, cudaGetErrorString(err)); size/1024.0/1024.0, cudaGetErrorString(err));
return nullptr; return nullptr;
} }
@ -6023,18 +6049,18 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) { if (nb0 == ts && nb1 == ts*ne0/bs) {
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
} else if (nb0 == ts) {
return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
} else {
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
// pretend the row is a matrix with cols=1
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
if (r != cudaSuccess) return r;
}
return cudaSuccess;
} }
if (nb0 == ts) {
return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
}
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
// pretend the row is a matrix with cols=1
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
if (r != cudaSuccess) { return r; }
}
return cudaSuccess;
} }
static void ggml_cuda_op_repeat( static void ggml_cuda_op_repeat(
@ -6356,6 +6382,7 @@ static int64_t get_row_rounding(ggml_type type) {
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
return max_compute_capability >= CC_RDNA2 ? 128 : 64; return max_compute_capability >= CC_RDNA2 ? 128 : 64;
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1; return 1;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
return max_compute_capability >= CC_RDNA2 ? 128 : 32; return max_compute_capability >= CC_RDNA2 ? 128 : 32;
@ -6378,6 +6405,7 @@ static int64_t get_row_rounding(ggml_type type) {
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
return 64; return 64;
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1; return 1;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
@ -6987,7 +7015,7 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne01 = src0->ne[1]; const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3]; const int64_t ne03 = src0->ne[3];
const int64_t nrows0 = ggml_nrows(src0); // const int64_t nrows0 = ggml_nrows(src0);
const int64_t ne10 = src1->ne[0]; const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1]; const int64_t ne11 = src1->ne[1];
@ -7088,7 +7116,7 @@ static void ggml_cuda_op_mul_mat(
if (src0_on_device && src0_is_contiguous) { if (src0_on_device && src0_is_contiguous) {
src0_dd[id] = (char *) src0_extra->data_device[id]; src0_dd[id] = (char *) src0_extra->data_device[id];
} else { } else {
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); // const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
} }
@ -7321,7 +7349,7 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
} }
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
if (!g_cublas_loaded) return false; if (!g_cublas_loaded) { return false; }
const int64_t ne10 = src1->ne[0]; const int64_t ne10 = src1->ne[0];
@ -7399,7 +7427,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
} }
__global__ void k_compute_batched_ptrs( __global__ static void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
const void ** ptrs_src, void ** ptrs_dst, const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13, int ne12, int ne13,
@ -8015,7 +8043,7 @@ void ggml_cuda_free_scratch() {
} }
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
if (!g_cublas_loaded) return false; if (!g_cublas_loaded) { return false; }
ggml_cuda_func_t func; ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@ -8314,14 +8342,14 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
UNUSED(cgraph); UNUSED(cgraph);
} }
static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { [[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(!"not implemented"); GGML_ASSERT(!"not implemented");
UNUSED(backend); UNUSED(backend);
UNUSED(plan); UNUSED(plan);
} }
static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { [[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(!"not implemented"); GGML_ASSERT(!"not implemented");
UNUSED(backend); UNUSED(backend);
@ -8337,8 +8365,9 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i]; ggml_tensor * node = cgraph->nodes[i];
if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) {
continue; continue;
}
assert(node->backend == GGML_BACKEND_GPU); assert(node->backend == GGML_BACKEND_GPU);
for (int j = 0; j < GGML_MAX_SRC; j++) { for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) { if (node->src[j] != nullptr) {

View file

@ -345,10 +345,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
} }
} }
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.maxTransferRate != 0) { if (ctx->device.maxTransferRate != 0) {
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
} else { } else {
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__); GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
} }
@ -541,11 +541,11 @@ bool ggml_metal_add_buffer(
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) { if (ctx->buffers[ctx->n_buffers].metal == nil) {
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
return false; return false;
} }
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
++ctx->n_buffers; ++ctx->n_buffers;
} else { } else {
@ -565,11 +565,11 @@ bool ggml_metal_add_buffer(
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) { if (ctx->buffers[ctx->n_buffers].metal == nil) {
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
return false; return false;
} }
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
if (i + size_step < size) { if (i + size_step < size) {
GGML_METAL_LOG_INFO("\n"); GGML_METAL_LOG_INFO("\n");
} }

View file

@ -19,7 +19,7 @@
#ifdef __wasm_simd128__ #ifdef __wasm_simd128__
#include <wasm_simd128.h> #include <wasm_simd128.h>
#else #else
#ifdef __POWER9_VECTOR__ #if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
#include <altivec.h> #include <altivec.h>
#undef bool #undef bool
#define bool _Bool #define bool _Bool

94
ggml.c
View file

@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32(
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne02 == ne12); GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13); GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne03 == ne13);
// we don't support permuted src0 or src1 // we don't support permuted src0 or src1
GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float));
@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32(
// GGML_ASSERT(nb1 <= nb2); // GGML_ASSERT(nb1 <= nb2);
// GGML_ASSERT(nb2 <= nb3); // GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) // TODO: #if defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
bool use_blas = ggml_is_matrix(src0) &&
ggml_is_matrix(src1) &&
ggml_is_contiguous(src0) &&
(ggml_is_contiguous(src1) || ggml_is_transposed(src1));
#endif
if (params->type == GGML_TASK_INIT) { if (params->type == GGML_TASK_INIT) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
if (use_blas) {
return;
}
#endif
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
return; return;
} }
@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32(
return; return;
} }
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (use_blas) {
if (params->ith != 0) { // All threads other than the first do no work.
return;
}
// Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
// src0: (k,n)
// src1: (k,m)
// dst: (m,n)
//
// Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
// Also expressed as (major,minor)
// a: (m,k): so src1 transposed
// b: (k,n): so src0
// c: (m,n)
//
// However, if ggml_is_transposed(src1) is true, then
// src1->data already contains a transposed version, so sgemm mustn't
// transpose it further.
int n = src0->ne[0];
int k = src0->ne[1];
int m = src1->ne[0];
int transposeA, lda;
if (!ggml_is_transposed(src1)) {
transposeA = CblasTrans;
lda = m;
} else {
transposeA = CblasNoTrans;
lda = k;
}
float * a = (float *) ((char *) src1->data);
float * b = (float *) ((char *) src0->data);
float * c = (float *) ((char *) dst->data);
cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
return;
}
#endif
// dst[:,:,:,:] = 0 // dst[:,:,:,:] = 0
// for i2,i3: // for i2,i3:
// for i1: // for i1:
@ -18399,24 +18452,29 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
} }
const char * gguf_get_key(const struct gguf_context * ctx, int key_id) { const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
return ctx->kv[key_id].key.data; return ctx->kv[key_id].key.data;
} }
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) { enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
return ctx->kv[key_id].type; return ctx->kv[key_id].type;
} }
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) { enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
return ctx->kv[key_id].value.arr.type; return ctx->kv[key_id].value.arr.type;
} }
const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) { const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
return ctx->kv[key_id].value.arr.data; return ctx->kv[key_id].value.arr.data;
} }
const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
struct gguf_kv * kv = &ctx->kv[key_id]; struct gguf_kv * kv = &ctx->kv[key_id];
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
@ -18424,70 +18482,90 @@ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i
} }
int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
return ctx->kv[key_id].value.arr.n; return ctx->kv[key_id].value.arr.n;
} }
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) { uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
return ctx->kv[key_id].value.uint8; return ctx->kv[key_id].value.uint8;
} }
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) { int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
return ctx->kv[key_id].value.int8; return ctx->kv[key_id].value.int8;
} }
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) { uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
return ctx->kv[key_id].value.uint16; return ctx->kv[key_id].value.uint16;
} }
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) { int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
return ctx->kv[key_id].value.int16; return ctx->kv[key_id].value.int16;
} }
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) { uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
return ctx->kv[key_id].value.uint32; return ctx->kv[key_id].value.uint32;
} }
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) { int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
return ctx->kv[key_id].value.int32; return ctx->kv[key_id].value.int32;
} }
float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) { float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
return ctx->kv[key_id].value.float32; return ctx->kv[key_id].value.float32;
} }
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) { uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
return ctx->kv[key_id].value.uint64; return ctx->kv[key_id].value.uint64;
} }
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) { int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
return ctx->kv[key_id].value.int64; return ctx->kv[key_id].value.int64;
} }
double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) { double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
return ctx->kv[key_id].value.float64; return ctx->kv[key_id].value.float64;
} }
bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) { bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
return ctx->kv[key_id].value.bool_; return ctx->kv[key_id].value.bool_;
} }
const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) { const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
return ctx->kv[key_id].value.str.data; return ctx->kv[key_id].value.str.data;
} }
const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
return &ctx->kv[key_id].value;
}
int gguf_get_n_tensors(const struct gguf_context * ctx) { int gguf_get_n_tensors(const struct gguf_context * ctx) {
return ctx->header.n_tensors; return ctx->header.n_tensors;
} }

1
ggml.h
View file

@ -2045,6 +2045,7 @@ extern "C" {
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);

View file

@ -117,17 +117,18 @@ class SpecialVocab:
def _try_load_from_tokenizer_json(self, path: Path) -> bool: def _try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer_file = path / 'tokenizer.json' tokenizer_file = path / 'tokenizer.json'
if not tokenizer_file.is_file(): if tokenizer_file.is_file():
return False with open(tokenizer_file, encoding = 'utf-8') as f:
with open(tokenizer_file, encoding = 'utf-8') as f: tokenizer = json.load(f)
tokenizer = json.load(f) if self.load_merges:
if self.load_merges: merges = tokenizer.get('model', {}).get('merges')
merges = tokenizer.get('model', {}).get('merges') if isinstance(merges, list) and merges and isinstance(merges[0], str):
if isinstance(merges, list) and merges and isinstance(merges[0], str): self.merges = merges
self.merges = merges added_tokens = tokenizer.get('added_tokens', {})
else:
added_tokens = {}
tokenizer_config_file = path / 'tokenizer_config.json' tokenizer_config_file = path / 'tokenizer_config.json'
added_tokens = tokenizer.get('added_tokens') if not tokenizer_config_file.is_file():
if added_tokens is None or not tokenizer_config_file.is_file():
return True return True
with open(tokenizer_config_file, encoding = 'utf-8') as f: with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f) tokenizer_config = json.load(f)
@ -135,6 +136,10 @@ class SpecialVocab:
add_entry = tokenizer_config.get(f'add_{typ}_token') add_entry = tokenizer_config.get(f'add_{typ}_token')
if isinstance(add_entry, bool): if isinstance(add_entry, bool):
self.add_special_token[typ] = add_entry self.add_special_token[typ] = add_entry
if not added_tokens:
# We will need this to get the content for the token, so if it's empty
# may as well just give up.
continue
entry = tokenizer_config.get(f'{typ}_token') entry = tokenizer_config.get(f'{typ}_token')
if isinstance(entry, str): if isinstance(entry, str):
tc_content = entry tc_content = entry

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "gguf" name = "gguf"
version = "0.5.2" version = "0.5.3"
description = "Read and write ML models in GGUF for GGML" description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"] authors = ["GGML <ggml@ggml.ai>"]
packages = [ packages = [

View file

@ -86,13 +86,14 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8") curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8")
else: else:
curr["value"] = field.parts[-1].tolist()[0] curr["value"] = field.parts[-1].tolist()[0]
for idx, tensor in enumerate(reader.tensors): if not args.no_tensors:
tensors[tensor.name] = { for idx, tensor in enumerate(reader.tensors):
"index": idx, tensors[tensor.name] = {
"shape": tensor.shape.tolist(), "index": idx,
"type": tensor.tensor_type.name, "shape": tensor.shape.tolist(),
"offset": tensor.field.offset, "type": tensor.tensor_type.name,
} "offset": tensor.field.offset,
}
json.dump(result, sys.stdout) json.dump(result, sys.stdout)

197
llama.cpp
View file

@ -91,7 +91,7 @@
#define LLAMA_ATTRIBUTE_FORMAT(...) #define LLAMA_ATTRIBUTE_FORMAT(...)
#endif #endif
#define LLAMA_MAX_NODES 4096 #define LLAMA_MAX_NODES 8192
// //
// logging // logging
@ -255,6 +255,8 @@ enum llm_kv {
LLM_KV_TOKENIZER_UNK_ID, LLM_KV_TOKENIZER_UNK_ID,
LLM_KV_TOKENIZER_SEP_ID, LLM_KV_TOKENIZER_SEP_ID,
LLM_KV_TOKENIZER_PAD_ID, LLM_KV_TOKENIZER_PAD_ID,
LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_RWKV,
}; };
@ -303,6 +305,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
}; };
@ -600,6 +604,60 @@ static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
return LLAMA_ROPE_SCALING_UNSPECIFIED; return LLAMA_ROPE_SCALING_UNSPECIFIED;
} }
static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
switch (type) {
case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]);
case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]);
case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]);
case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]);
case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]);
case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]);
case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]);
case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]);
case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false";
default: return format("unknown type %d", type);
}
}
static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) {
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
switch (type) {
case GGUF_TYPE_STRING:
return gguf_get_val_str(ctx_gguf, i);
case GGUF_TYPE_ARRAY:
{
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
int arr_n = gguf_get_arr_n(ctx_gguf, i);
const void * data = gguf_get_arr_data(ctx_gguf, i);
std::stringstream ss;
ss << "[";
for (int j = 0; j < arr_n; j++) {
if (arr_type == GGUF_TYPE_STRING) {
std::string val = gguf_get_arr_str(ctx_gguf, i, j);
// escape quotes
replace_all(val, "\\", "\\\\");
replace_all(val, "\"", "\\\"");
ss << '"' << val << '"';
} else if (arr_type == GGUF_TYPE_ARRAY) {
ss << "???";
} else {
ss << gguf_data_to_str(arr_type, data, j);
}
if (j < arr_n - 1) {
ss << ", ";
}
}
ss << "]";
return ss.str();
}
default:
return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
}
}
// //
// ggml helpers // ggml helpers
// //
@ -1083,9 +1141,9 @@ enum e_model {
MODEL_70B, MODEL_70B,
}; };
static const size_t kB = 1024; static const size_t kiB = 1024;
static const size_t MB = 1024*kB; static const size_t MiB = 1024*kiB;
static const size_t GB = 1024*MB; static const size_t GiB = 1024*MiB;
struct llama_hparams { struct llama_hparams {
bool vocab_only; bool vocab_only;
@ -1276,6 +1334,9 @@ struct llama_vocab {
id special_sep_id = -1; id special_sep_id = -1;
id special_pad_id = -1; id special_pad_id = -1;
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
id linefeed_id = 13; id linefeed_id = 13;
id special_prefix_id = 32007; id special_prefix_id = 32007;
id special_middle_id = 32009; id special_middle_id = 32009;
@ -1320,6 +1381,9 @@ struct llama_model {
int n_gpu_layers; int n_gpu_layers;
// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
// context // context
struct ggml_context * ctx = NULL; struct ggml_context * ctx = NULL;
@ -1481,7 +1545,7 @@ static bool llama_kv_cache_init(
vram_kv_cache += ggml_nbytes(cache.k); vram_kv_cache += ggml_nbytes(cache.k);
} }
if (vram_kv_cache > 0) { if (vram_kv_cache > 0) {
LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MiB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
} }
} }
#endif #endif
@ -1778,10 +1842,10 @@ struct llama_model_loader {
case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
default: default:
{ {
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
ftype = LLAMA_FTYPE_ALL_F32; ftype = LLAMA_FTYPE_ALL_F32;
} break; } break;
} }
// this is a way to mark that we have "guessed" the file type // this is a way to mark that we have "guessed" the file type
@ -1795,10 +1859,20 @@ struct llama_model_loader {
} }
for (int i = 0; i < n_kv; i++) { for (int i = 0; i < n_kv; i++) {
const char * name = gguf_get_key(ctx_gguf, i); const char * name = gguf_get_key(ctx_gguf, i);
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
const std::string type_name =
type == GGUF_TYPE_ARRAY
? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(ctx_gguf, i)), gguf_get_arr_n(ctx_gguf, i))
: gguf_type_name(type);
LLAMA_LOG_INFO("%s: - kv %3d: %42s %-8s\n", __func__, i, name, gguf_type_name(type)); std::string value = gguf_kv_to_str(ctx_gguf, i);
const size_t MAX_VALUE_LEN = 40;
if (value.size() > MAX_VALUE_LEN) {
value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
}
LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
} }
// print type counts // print type counts
@ -2093,6 +2167,17 @@ static void llm_load_hparams(
auto & hparams = model.hparams; auto & hparams = model.hparams;
// get metadata as string
for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
enum gguf_type type = gguf_get_kv_type(ctx, i);
if (type == GGUF_TYPE_ARRAY) {
continue;
}
const char * name = gguf_get_key(ctx, i);
const std::string value = gguf_kv_to_str(ctx, i);
model.gguf_kv.emplace(name, value);
}
// get general kv // get general kv
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
@ -2388,6 +2473,23 @@ static void llm_load_vocab(
__func__, key.c_str(), id, old_id); __func__, key.c_str(), id, old_id);
id = old_id; id = old_id;
} }
}
// Handle add_bos_token and add_eos_token
std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
int kid = gguf_find_key(ctx, key.c_str());
enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
}
key = kv(LLM_KV_TOKENIZER_ADD_EOS);
kid = gguf_find_key(ctx, key.c_str());
ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
} }
} }
@ -2519,8 +2621,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
if (ml.n_bytes < GB) { if (ml.n_bytes < GiB) {
LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
} else { } else {
LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
} }
@ -2558,7 +2660,7 @@ static void llm_load_tensors(
ml.calc_sizes(ctx_size, mmapped_size); ml.calc_sizes(ctx_size, mmapped_size);
LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, ctx_size/1024.0/1024.0);
// create the ggml context // create the ggml context
{ {
@ -3207,7 +3309,7 @@ static void llm_load_tensors(
ctx_size + ctx_size +
mmapped_size - vram_weights; // weights in VRAM not in memory mmapped_size - vram_weights; // weights in VRAM not in memory
LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: mem required = %7.2f MiB\n", __func__, mem_required / 1024.0 / 1024.0);
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@ -3226,7 +3328,7 @@ static void llm_load_tensors(
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: VRAM used: %.2f MiB\n", __func__, vram_weights / 1024.0 / 1024.0);
#else #else
(void) n_gpu_layers; (void) n_gpu_layers;
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@ -7938,7 +8040,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
workers.clear(); workers.clear();
} }
LLAMA_LOG_INFO("size = %8.2f MB -> %8.2f MB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
int64_t tot_count = 0; int64_t tot_count = 0;
for (size_t i = 0; i < hist_cur.size(); i++) { for (size_t i = 0; i < hist_cur.size(); i++) {
hist_all[i] += hist_cur[i]; hist_all[i] += hist_cur[i];
@ -8478,7 +8580,7 @@ struct llama_context * llama_new_context_with_model(
{ {
const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0);
} }
// resized during inference // resized during inference
@ -8523,7 +8625,7 @@ struct llama_context * llama_new_context_with_model(
// measure memory requirements for the graph // measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MiB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
// recreate allocator with exact memory requirements // recreate allocator with exact memory requirements
ggml_allocr_free(ctx->alloc); ggml_allocr_free(ctx->alloc);
@ -8537,7 +8639,7 @@ struct llama_context * llama_new_context_with_model(
#endif #endif
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
ggml_cuda_set_scratch_size(alloc_size); ggml_cuda_set_scratch_size(alloc_size);
LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MiB\n", __func__, alloc_size / 1024.0 / 1024.0);
// calculate total VRAM usage // calculate total VRAM usage
auto add_tensor = [](const ggml_tensor * t, size_t & size) { auto add_tensor = [](const ggml_tensor * t, size_t & size) {
@ -8557,10 +8659,10 @@ struct llama_context * llama_new_context_with_model(
size_t ctx_vram_size = alloc_size + kv_vram_size; size_t ctx_vram_size = alloc_size + kv_vram_size;
size_t total_vram_size = model_vram_size + ctx_vram_size; size_t total_vram_size = model_vram_size + ctx_vram_size;
LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__, LLAMA_LOG_INFO("%s: total VRAM used: %.2f MiB (model: %.2f MiB, context: %.2f MiB)\n", __func__,
total_vram_size / 1024.0 / 1024.0, total_vram_size / 1024.0 / 1024.0,
model_vram_size / 1024.0 / 1024.0, model_vram_size / 1024.0 / 1024.0,
ctx_vram_size / 1024.0 / 1024.0); ctx_vram_size / 1024.0 / 1024.0);
#endif #endif
} }
@ -8581,7 +8683,7 @@ struct llama_context * llama_new_context_with_model(
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
LLAMA_LOG_INFO("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); LLAMA_LOG_INFO("%s: max tensor size = %8.2f MiB\n", __func__, max_size/1024.0/1024.0);
#define LLAMA_METAL_CHECK_BUF(result) \ #define LLAMA_METAL_CHECK_BUF(result) \
if (!(result)) { \ if (!(result)) { \
@ -8647,6 +8749,45 @@ float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train; return model->hparams.rope_freq_scale_train;
} }
int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
return snprintf(buf, buf_size, "%s", it->second.c_str());
}
int llama_model_meta_count(const struct llama_model * model) {
return (int)model->gguf_kv.size();
}
int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)model->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = model->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->first.c_str());
}
int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)model->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = model->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->second.c_str());
}
int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
return snprintf(buf, buf_size, "%s %s %s", return snprintf(buf, buf_size, "%s %s %s",
llama_model_arch_name(model->arch).c_str(), llama_model_arch_name(model->arch).c_str(),
@ -9288,6 +9429,14 @@ llama_token llama_token_nl(const struct llama_model * model) {
return model->vocab.linefeed_id; return model->vocab.linefeed_id;
} }
int llama_add_bos_token(const struct llama_model * model) {
return model->vocab.special_add_bos;
}
int llama_add_eos_token(const struct llama_model * model) {
return model->vocab.special_add_eos;
}
llama_token llama_token_prefix(const struct llama_model * model) { llama_token llama_token_prefix(const struct llama_model * model) {
return model->vocab.special_prefix_id; return model->vocab.special_prefix_id;
} }

23
llama.h
View file

@ -301,6 +301,23 @@ extern "C" {
// Get the model's RoPE frequency scaling factor // Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
// - GGUF array values are not supported by these functions
// Get metadata value as a string by key name
LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
// Get the number of metadata key/value pairs
LLAMA_API int llama_model_meta_count(const struct llama_model * model);
// Get metadata key name by index
LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
// Get metadata value as a string by index
LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
// Get a string describing the model type // Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
@ -517,6 +534,12 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int llama_add_bos_token(const struct llama_model * model);
// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int llama_add_eos_token(const struct llama_model * model);
// codellama infill tokens // codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle