Merge branch 'master' into compilade/lazy-convert-hf
This commit is contained in:
commit
bffdaf4010
43 changed files with 1646 additions and 179 deletions
|
@ -103,6 +103,8 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
|
|||
set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
||||
"llama: max. batch size for using peer access")
|
||||
option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
|
||||
option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF)
|
||||
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
|
||||
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
|
||||
option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
|
||||
|
@ -409,6 +411,9 @@ if (LLAMA_CUDA)
|
|||
if (LLAMA_CUDA_FORCE_MMQ)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
||||
endif()
|
||||
if (LLAMA_CUDA_NO_VMM)
|
||||
add_compile_definitions(GGML_CUDA_NO_VMM)
|
||||
endif()
|
||||
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
|
||||
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
|
||||
if (DEFINED LLAMA_CUDA_DMMV_Y)
|
||||
|
@ -434,7 +439,11 @@ if (LLAMA_CUDA)
|
|||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
endif()
|
||||
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
|
||||
if (LLAMA_CUDA_NO_VMM)
|
||||
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
|
||||
else()
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
|
||||
endif()
|
||||
|
||||
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
||||
# 52 == lowest CUDA 12 standard
|
||||
|
|
50
README.md
50
README.md
|
@ -20,7 +20,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
|||
|
||||
### Hot topics
|
||||
|
||||
- **BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920**
|
||||
- **Initial Flash-Attention support: https://github.com/ggerganov/llama.cpp/pull/5021**
|
||||
- BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920
|
||||
- MoE memory layout has been updated - reconvert models for `mmap` support and regenerate `imatrix` https://github.com/ggerganov/llama.cpp/pull/6387
|
||||
- Model sharding instructions using `gguf-split` https://github.com/ggerganov/llama.cpp/discussions/6404
|
||||
- Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225
|
||||
|
@ -935,17 +936,25 @@ If your issue is with model generation quality, then please at least scan the fo
|
|||
|
||||
### Android
|
||||
|
||||
#### Build on Android using Termux
|
||||
[Termux](https://github.com/termux/termux-app#installation) is a method to execute `llama.cpp` on an Android device (no root required).
|
||||
```
|
||||
apt update && apt upgrade -y
|
||||
apt install git make cmake
|
||||
```
|
||||
|
||||
It's recommended to move your model inside the `~/` directory for best performance:
|
||||
```
|
||||
cd storage/downloads
|
||||
mv model.gguf ~/
|
||||
```
|
||||
|
||||
[Get the code](https://github.com/ggerganov/llama.cpp#get-the-code) & [follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
|
||||
|
||||
#### Building the Project using Android NDK
|
||||
You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/).
|
||||
|
||||
First, install the essential packages for termux:
|
||||
```
|
||||
pkg install clang wget git cmake
|
||||
```
|
||||
Second, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
|
||||
|
||||
You can execute the following commands on your computer to avoid downloading the NDK to your mobile. Of course, you can also do this in Termux.
|
||||
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.
|
||||
|
||||
Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux:
|
||||
```
|
||||
$ mkdir build-android
|
||||
$ cd build-android
|
||||
|
@ -953,7 +962,9 @@ $ export NDK=<your_ndk_directory>
|
|||
$ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
|
||||
$ make
|
||||
```
|
||||
Install [termux](https://termux.dev/) on your device and run `termux-setup-storage` to get access to your SD card.
|
||||
|
||||
Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).
|
||||
|
||||
Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:
|
||||
|
||||
(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
|
||||
|
@ -975,25 +986,10 @@ $cd /data/data/com.termux/files/home/bin
|
|||
$./main -m ../model/llama-2-7b-chat.Q4_K_M.gguf -n 128 -cml
|
||||
```
|
||||
|
||||
Here is a demo of an interactive session running on Pixel 5 phone:
|
||||
Here's a demo of an interactive session running on Pixel 5 phone:
|
||||
|
||||
https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
|
||||
|
||||
#### Build on Android using Termux
|
||||
[Termux](https://github.com/termux/termux-app#installation) is an alternative to execute `llama.cpp` on an Android device (no root required).
|
||||
```
|
||||
apt update && apt upgrade -y
|
||||
apt install git
|
||||
```
|
||||
|
||||
It's recommended to move your model inside the `~/` directory for best performance:
|
||||
```
|
||||
cd storage/downloads
|
||||
mv model.gguf ~/
|
||||
```
|
||||
|
||||
[Follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`.
|
||||
|
||||
### Docker
|
||||
|
||||
#### Prerequisites
|
||||
|
|
11
ci/run.sh
11
ci/run.sh
|
@ -160,9 +160,8 @@ function gg_run_test_scripts_debug {
|
|||
|
||||
set -e
|
||||
|
||||
# TODO: too slow, run on dedicated node
|
||||
#(cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
|
||||
#(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
|
||||
(cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
|
||||
(cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log
|
||||
|
||||
set +e
|
||||
}
|
||||
|
@ -695,8 +694,10 @@ test $ret -eq 0 && gg_run ctest_release
|
|||
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
test $ret -eq 0 && gg_run embd_bge_small
|
||||
|
||||
test $ret -eq 0 && gg_run test_scripts_debug
|
||||
test $ret -eq 0 && gg_run test_scripts_release
|
||||
if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
|
||||
test $ret -eq 0 && gg_run test_scripts_debug
|
||||
test $ret -eq 0 && gg_run test_scripts_release
|
||||
fi
|
||||
|
||||
if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then
|
||||
if [ -z ${GG_BUILD_CUDA} ]; then
|
||||
|
|
|
@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
params.instruct = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-cnv" || arg == "--conversation") {
|
||||
params.conversation = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-cml" || arg == "--chatml") {
|
||||
params.chatml = true;
|
||||
return true;
|
||||
|
@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" --version show version and build info\n");
|
||||
printf(" -i, --interactive run in interactive mode\n");
|
||||
printf(" --interactive-first run in interactive mode and wait for input right away\n");
|
||||
printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n");
|
||||
printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");
|
||||
printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n");
|
||||
printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
|
||||
|
|
|
@ -140,6 +140,7 @@ struct gpt_params {
|
|||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
bool interactive = false; // interactive mode
|
||||
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
|
||||
bool chatml = false; // chatml mode (used for models trained on chatml syntax)
|
||||
bool prompt_cache_all = false; // save user input and generations to prompt cache
|
||||
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
|
||||
|
|
|
@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
|||
|
||||
result->prev.resize(params.n_prev);
|
||||
|
||||
result->n_considered = 0;
|
||||
|
||||
llama_sampling_set_rng_seed(result, params.seed);
|
||||
|
||||
return result;
|
||||
|
@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
|||
|
||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||
ctx->cur.clear();
|
||||
ctx->n_considered = 0;
|
||||
}
|
||||
|
||||
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
|
||||
|
@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
|
|||
}
|
||||
}
|
||||
|
||||
ctx_sampling->n_considered = cur_p.size;
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
|
|
|
@ -81,6 +81,7 @@ struct llama_sampling_context {
|
|||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> prev;
|
||||
std::vector<llama_token_data> cur;
|
||||
size_t n_considered;
|
||||
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
|
|
@ -67,6 +67,9 @@ models = [
|
|||
{"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
|
||||
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
|
||||
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||
]
|
||||
|
||||
# make directory "models/tokenizers" if it doesn't exist
|
||||
|
@ -150,6 +153,8 @@ for model in models:
|
|||
# print the "pre_tokenizer" content from the tokenizer.json
|
||||
with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
normalizer = cfg["normalizer"]
|
||||
logger.info("normalizer: " + json.dumps(normalizer, indent=4))
|
||||
pre_tokenizer = cfg["pre_tokenizer"]
|
||||
logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
|
||||
|
||||
|
|
|
@ -397,6 +397,15 @@ class Model:
|
|||
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
|
||||
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
|
||||
res = "command-r"
|
||||
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
|
||||
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
|
||||
res = "qwen2"
|
||||
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
|
||||
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
|
||||
res = "olmo"
|
||||
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
|
||||
# ref: https://huggingface.co/databricks/dbrx-instruct
|
||||
res = "dbrx"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
@ -2248,8 +2257,9 @@ class OlmoModel(Model):
|
|||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_layer_norm_eps(1e-5)
|
||||
if "clip_qkv" in self.hparams is not None:
|
||||
self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"])
|
||||
clip_qkv = self.hparams.get("clip_qkv")
|
||||
if clip_qkv is not None:
|
||||
self.gguf_writer.add_clamp_kqv(clip_qkv)
|
||||
|
||||
# Same as super class, but permuting q_proj, k_proj
|
||||
# Copied from: LlamaModel
|
||||
|
|
47
convert.py
47
convert.py
|
@ -1512,25 +1512,27 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
if args.big_endian:
|
||||
endianess = gguf.GGUFEndian.BIG
|
||||
|
||||
params = Params.load(model_plus)
|
||||
if params.n_ctx == -1:
|
||||
if args.ctx is None:
|
||||
msg = """\
|
||||
The model doesn't have a context size, and you didn't specify one with --ctx
|
||||
Please specify one with --ctx:
|
||||
- LLaMA v1: --ctx 2048
|
||||
- LLaMA v2: --ctx 4096"""
|
||||
parser.error(textwrap.dedent(msg))
|
||||
params.n_ctx = args.ctx
|
||||
params = None
|
||||
if args.pad_vocab or not args.vocab_only:
|
||||
params = Params.load(model_plus)
|
||||
if params.n_ctx == -1:
|
||||
if args.ctx is None:
|
||||
msg = """\
|
||||
The model doesn't have a context size, and you didn't specify one with --ctx
|
||||
Please specify one with --ctx:
|
||||
- LLaMA v1: --ctx 2048
|
||||
- LLaMA v2: --ctx 4096"""
|
||||
parser.error(textwrap.dedent(msg))
|
||||
params.n_ctx = args.ctx
|
||||
|
||||
if args.outtype:
|
||||
params.ftype = {
|
||||
"f32": GGMLFileType.AllF32,
|
||||
"f16": GGMLFileType.MostlyF16,
|
||||
"q8_0": GGMLFileType.MostlyQ8_0,
|
||||
}[args.outtype]
|
||||
if args.outtype:
|
||||
params.ftype = {
|
||||
"f32": GGMLFileType.AllF32,
|
||||
"f16": GGMLFileType.MostlyF16,
|
||||
"q8_0": GGMLFileType.MostlyQ8_0,
|
||||
}[args.outtype]
|
||||
|
||||
logger.info(f"params = {params}")
|
||||
logger.info(f"params = {params}")
|
||||
|
||||
model_parent_path = model_plus.paths[0].parent
|
||||
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
|
||||
|
@ -1543,6 +1545,17 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
if not args.outfile:
|
||||
raise ValueError("need --outfile if using --vocab-only")
|
||||
outfile = args.outfile
|
||||
if params is None:
|
||||
params = Params(
|
||||
n_vocab = vocab.vocab_size,
|
||||
n_embd = 1,
|
||||
n_layer = 1,
|
||||
n_ctx = 1,
|
||||
n_ff = 1,
|
||||
n_head = 1,
|
||||
n_head_kv = 1,
|
||||
f_norm_eps = 1e-5,
|
||||
)
|
||||
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
|
||||
endianess=endianess, pad_vocab=args.pad_vocab)
|
||||
logger.info(f"Wrote {outfile}")
|
||||
|
|
|
@ -23,7 +23,7 @@ Install BLIS:
|
|||
sudo make install
|
||||
```
|
||||
|
||||
We recommend using openmp since it's easier to modify the cores been used.
|
||||
We recommend using openmp since it's easier to modify the cores being used.
|
||||
|
||||
### llama.cpp compilation
|
||||
|
||||
|
|
|
@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc
|
|||
|
||||
This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
|
||||
|
||||
Have a look to existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
|
||||
Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
|
||||
|
||||
When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support of missing backend operations can be added in another PR.
|
||||
When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
|
||||
|
||||
Note: to debug the inference graph: you can use [eval-callback](../examples/eval-callback).
|
||||
|
||||
|
|
|
@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
||||
|
||||
auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
||||
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
|
||||
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) {
|
||||
return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
|
||||
} else if (a->type == GGML_TYPE_F32) {
|
||||
return ggml_add(ctx, a, b);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
struct Stats {
|
||||
std::vector<float> values;
|
||||
std::vector<int> counts;
|
||||
int ncall = 0;
|
||||
};
|
||||
|
||||
|
@ -121,12 +122,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
|||
auto & e = m_stats[wname];
|
||||
|
||||
++e.ncall;
|
||||
// NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
|
||||
// using the following line, we can correct for that if needed by replacing the line above with:
|
||||
//if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
|
||||
|
||||
if (e.values.empty()) {
|
||||
e.values.resize(src1->ne[0]*n_as, 0);
|
||||
e.counts.resize(src1->ne[0]*n_as, 0);
|
||||
}
|
||||
else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
|
||||
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
|
||||
|
@ -153,6 +152,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
|||
|
||||
for (int j = 0; j < (int)src1->ne[0]; ++j) {
|
||||
e.values[e_start + j] += x[j]*x[j];
|
||||
e.counts[e_start + j]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -170,6 +170,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
|||
auto& e = m_stats[wname];
|
||||
if (e.values.empty()) {
|
||||
e.values.resize(src1->ne[0], 0);
|
||||
e.counts.resize(src1->ne[0], 0);
|
||||
}
|
||||
else if (e.values.size() != (size_t)src1->ne[0]) {
|
||||
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
|
||||
|
@ -183,6 +184,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
|||
const float * x = data + row * src1->ne[0];
|
||||
for (int j = 0; j < (int)src1->ne[0]; ++j) {
|
||||
e.values[j] += x[j]*x[j];
|
||||
e.counts[j]++;
|
||||
}
|
||||
}
|
||||
if (e.ncall > m_last_call) {
|
||||
|
@ -222,7 +224,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co
|
|||
out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
|
||||
int nval = p.second.values.size();
|
||||
out.write((const char *) &nval, sizeof(nval));
|
||||
if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
|
||||
if (nval > 0) {
|
||||
std::vector<float> tmp(nval);
|
||||
for (int i = 0; i < nval; i++) {
|
||||
tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
|
||||
}
|
||||
out.write((const char*)tmp.data(), nval*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// Write the number of call the matrix was computed with
|
||||
|
@ -270,14 +278,28 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma
|
|||
imatrix_data = {};
|
||||
return false;
|
||||
}
|
||||
e.values.resize(nval);
|
||||
in.read((char*)e.values.data(), nval*sizeof(float));
|
||||
|
||||
// When re-called from load_imatrix() with add set, this will already be created.
|
||||
if (e.values.empty()) {
|
||||
e.values.resize(nval, 0);
|
||||
e.counts.resize(nval, 0);
|
||||
}
|
||||
|
||||
std::vector<float> tmp(nval);
|
||||
in.read((char*)tmp.data(), nval*sizeof(float));
|
||||
if (in.fail()) {
|
||||
printf("%s: failed reading data for entry %d\n",__func__,i);
|
||||
imatrix_data = {};
|
||||
return false;
|
||||
}
|
||||
e.ncall = ncall;
|
||||
|
||||
// Recreate the state as expected by save_imatrix(), and corerct for weighted sum.
|
||||
for (int i = 0; i < nval; i++) {
|
||||
e.values[i] += tmp[i];
|
||||
e.counts[i] += ncall;
|
||||
}
|
||||
e.ncall += ncall;
|
||||
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-pa
|
|||
python ./convert.py ../llava-v1.5-7b --skip-unknown
|
||||
```
|
||||
|
||||
Now both the LLaMA part and the image encoder is in the `llava-v1.5-7b` directory.
|
||||
Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory.
|
||||
|
||||
## LLaVA 1.6 gguf conversion
|
||||
1) First clone a LLaVA 1.6 model:
|
||||
|
|
|
@ -143,7 +143,7 @@ The `--ctx-size` option allows you to set the size of the prompt context used by
|
|||
|
||||
### Extended Context Size
|
||||
|
||||
Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model have a context length (max sequence length) of 4096 (4k) and the fine-tuned model have 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.
|
||||
Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model has a context length (max sequence length) of 4096 (4k) and the fine-tuned model has 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.
|
||||
|
||||
- `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model.
|
||||
|
||||
|
@ -286,7 +286,7 @@ These options help improve the performance and memory usage of the LLaMA models.
|
|||
|
||||
- `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilitizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes.
|
||||
- `--numa isolate`: Pin all threads to the NUMA node that the program starts on. This limits the number of cores and amount of memory that can be used, but guarantees all memory access remains local to the NUMA node.
|
||||
- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitraty core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus.
|
||||
- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitrary core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus.
|
||||
|
||||
These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root.
|
||||
|
||||
|
|
|
@ -362,6 +362,9 @@ int main(int argc, char ** argv) {
|
|||
params.interactive_first = true;
|
||||
params.antiprompt.emplace_back("<|im_start|>user\n");
|
||||
}
|
||||
else if (params.conversation) {
|
||||
params.interactive_first = true;
|
||||
}
|
||||
|
||||
// enable interactive mode if interactive start is specified
|
||||
if (params.interactive_first) {
|
||||
|
@ -733,7 +736,7 @@ int main(int argc, char ** argv) {
|
|||
// display text
|
||||
if (input_echo && display) {
|
||||
for (auto id : embd) {
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation);
|
||||
printf("%s", token_str.c_str());
|
||||
|
||||
if (embd.size() > 1) {
|
||||
|
@ -796,7 +799,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// deal with end of generation tokens in interactive mode
|
||||
if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
|
||||
LOG("found EOS token\n");
|
||||
LOG("found an EOG token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
if (!params.antiprompt.empty()) {
|
||||
|
@ -816,7 +819,7 @@ int main(int argc, char ** argv) {
|
|||
if (n_past > 0 && is_interacting) {
|
||||
LOG("waiting for user input\n");
|
||||
|
||||
if (params.instruct || params.chatml) {
|
||||
if (params.conversation || params.instruct || params.chatml) {
|
||||
printf("\n> ");
|
||||
}
|
||||
|
||||
|
@ -826,7 +829,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
std::string buffer;
|
||||
if (!params.input_prefix.empty()) {
|
||||
if (!params.input_prefix.empty() && !params.conversation) {
|
||||
LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
|
||||
printf("%s", params.input_prefix.c_str());
|
||||
}
|
||||
|
@ -850,7 +853,7 @@ int main(int argc, char ** argv) {
|
|||
// Entering a empty line lets the user pass control back
|
||||
if (buffer.length() > 1) {
|
||||
// append input suffix if any
|
||||
if (!params.input_suffix.empty()) {
|
||||
if (!params.input_suffix.empty() && !params.conversation) {
|
||||
LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
|
||||
printf("%s", params.input_suffix.c_str());
|
||||
}
|
||||
|
|
|
@ -46,7 +46,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
|||
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
|
||||
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
|
||||
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
|
||||
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "13.00G @ 7B", },
|
||||
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", },
|
||||
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
|
||||
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
|
||||
// Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
|
||||
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
|
||||
|
|
|
@ -62,6 +62,18 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
|
|||
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
|
||||
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
|
||||
- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
|
||||
- `--rope-scaling` : RoPE scaling method. Defaults to linear unless otherwise specified by the model. Options are `none`, `linear`, `yarn`
|
||||
- `--rope-freq-base N` : RoPE frequency base (default: loaded from model)
|
||||
- `--rope-freq-scale N`: RoPE frequency scaling factor, expands context by a factor of 1/N (e.g. 0.25)
|
||||
- `--yarn-ext-factor N` : YaRN: extrapolation mix factor (Default: 1.0, 0.0 = full interpolation)
|
||||
- `--yarn-attn-factor N` : YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
|
||||
- `--yarn-beta-slow N`: YaRN: High correction dim or alpha (default: 1.0)
|
||||
- `--yarn-beta-fast N`: YaRN: low correction dim or beta (default: 32.0)
|
||||
- `--pooling` : Pooling type for embeddings, use model default if unspecified. Options are `none`, `mean`, `cls`
|
||||
- `-dt N`, `--defrag-thold N`: KV cache defragmentation threshold (default: -1.0, < 0 = disabled)
|
||||
- `-fa`, `--flash-attn` : enable flash attention (default: disabled).
|
||||
- `-ctk TYPE`, `--cache-type-k TYPE` : KV cache data type for K (default: `f16`, options `f32`, `f16`, `q8_0`, `q4_0`, `q4_1`, `iq4_nl`, `q5_0`, or `q5_1`)
|
||||
- `-ctv TYPE`, `--cache-type-v TYPE` : KV cache type for V (default `f16`, see `-ctk` for options)
|
||||
|
||||
**If compiled with `LLAMA_SERVER_SSL=ON`**
|
||||
- `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key
|
||||
|
@ -260,7 +272,7 @@ node index.js
|
|||
|
||||
`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`
|
||||
|
||||
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0`
|
||||
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
|
||||
|
||||
`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`
|
||||
|
||||
|
@ -319,7 +331,7 @@ Notice that each `probs` is an array of length `n_probs`.
|
|||
|
||||
`content`: Set the text to tokenize.
|
||||
|
||||
Note that a special `BOS` token is never inserted.
|
||||
`add_special`: Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false`
|
||||
|
||||
- **POST** `/detokenize`: Convert tokens to text.
|
||||
|
||||
|
|
|
@ -2266,17 +2266,31 @@ struct server_context {
|
|||
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
||||
result.tok = id;
|
||||
|
||||
const int32_t n_probs = slot.sparams.n_probs;
|
||||
if (slot.sparams.temp <= 0 && n_probs > 0) {
|
||||
// for llama_sample_token_greedy we need to sort candidates
|
||||
llama_sample_softmax(ctx, &cur_p);
|
||||
}
|
||||
const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
|
||||
if (n_probs > 0) {
|
||||
const size_t n_considered = slot.ctx_sampling->n_considered;
|
||||
|
||||
for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
|
||||
result.probs.push_back({
|
||||
cur_p.data[i].id,
|
||||
cur_p.data[i].p
|
||||
});
|
||||
// Make sure at least n_probs top tokens are at the front of the vector:
|
||||
if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
|
||||
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
|
||||
}
|
||||
|
||||
if (slot.sparams.temp == 0.0f) {
|
||||
// With greedy sampling the probabilities have possibly not been calculated.
|
||||
for (size_t i = 0; i < n_probs; ++i) {
|
||||
result.probs.push_back({
|
||||
cur_p.data[i].id,
|
||||
i == 0 ? 1.0f : 0.0f
|
||||
});
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < n_probs; ++i) {
|
||||
result.probs.push_back({
|
||||
cur_p.data[i].id,
|
||||
i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
|
@ -3633,7 +3647,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> tokens;
|
||||
if (body.count("content") != 0) {
|
||||
tokens = ctx_server.tokenize(body["content"], false);
|
||||
const bool add_special = json_value(body, "add_special", false);
|
||||
tokens = ctx_server.tokenize(body["content"], add_special);
|
||||
}
|
||||
const json data = format_tokenizer_response(tokens);
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
|
|
|
@ -7,6 +7,7 @@ Feature: llama.cpp server
|
|||
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||
And a model file test-model.gguf
|
||||
And a model alias tinyllama-2
|
||||
And BOS token is 1
|
||||
And 42 as server seed
|
||||
# KV Cache corresponds to the total amount of tokens
|
||||
# that can be stored across all independent sequences: #4130
|
||||
|
@ -91,7 +92,18 @@ Feature: llama.cpp server
|
|||
"""
|
||||
What is the capital of France ?
|
||||
"""
|
||||
Then tokens can be detokenize
|
||||
Then tokens can be detokenized
|
||||
And tokens do not begin with BOS
|
||||
|
||||
Scenario: Tokenize w/ BOS
|
||||
Given adding special tokens
|
||||
When tokenizing:
|
||||
"""
|
||||
What is the capital of Germany?
|
||||
"""
|
||||
Then tokens begin with BOS
|
||||
Given first token is removed
|
||||
Then tokens can be detokenized
|
||||
|
||||
Scenario: Models available
|
||||
Given available models
|
||||
|
|
|
@ -376,6 +376,11 @@ def step_seed(context, seed):
|
|||
context.seed.append(seed)
|
||||
|
||||
|
||||
@step('BOS token is {bos:d}')
|
||||
def step_bos_token(context, bos):
|
||||
context.bos = bos
|
||||
|
||||
|
||||
@step('a prefix prompt')
|
||||
def step_prompt_prefix(context):
|
||||
context.prompt_prefix = context_text(context)
|
||||
|
@ -656,21 +661,29 @@ async def all_embeddings_are_generated(context):
|
|||
assert_embeddings(context.tasks_result.pop().pop())
|
||||
|
||||
|
||||
@step('adding special tokens')
|
||||
def step_tokenize_set_add_special(context):
|
||||
context.tokenize_add_special = True
|
||||
|
||||
|
||||
@step('tokenizing')
|
||||
@async_run_until_complete
|
||||
async def step_tokenize(context):
|
||||
context.tokenized_text = context_text(context)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tokenize_args = {
|
||||
"content": context.tokenized_text,
|
||||
}
|
||||
if getattr(context, 'tokenize_add_special', None) is not None:
|
||||
tokenize_args['add_special'] = context.tokenize_add_special
|
||||
async with session.post(f'{context.base_url}/tokenize',
|
||||
json={
|
||||
"content": context.tokenized_text,
|
||||
}) as response:
|
||||
json=tokenize_args) as response:
|
||||
assert response.status == 200
|
||||
tokenize_json = await response.json()
|
||||
context.tokens = tokenize_json['tokens']
|
||||
|
||||
|
||||
@step('tokens can be detokenize')
|
||||
@step('tokens can be detokenized')
|
||||
@async_run_until_complete
|
||||
async def step_detokenize(context):
|
||||
assert len(context.tokens) > 0
|
||||
|
@ -685,6 +698,21 @@ async def step_detokenize(context):
|
|||
assert context.tokenized_text == detokenize_json['content'].strip()
|
||||
|
||||
|
||||
@step('tokens begin with BOS')
|
||||
def step_strings_for_tokenization(context):
|
||||
assert context.tokens[0] == context.bos
|
||||
|
||||
|
||||
@step('tokens do not begin with BOS')
|
||||
def step_strings_for_tokenization(context):
|
||||
assert context.tokens[0] != context.bos
|
||||
|
||||
|
||||
@step('first token is removed')
|
||||
def step_strings_for_tokenization(context):
|
||||
context.tokens = context.tokens[1:]
|
||||
|
||||
|
||||
@step('an OPTIONS request is sent from {origin}')
|
||||
@async_run_until_complete
|
||||
async def step_options_request(context, origin):
|
||||
|
|
|
@ -49,18 +49,18 @@ extern bool server_log_json;
|
|||
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
|
||||
static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra);
|
||||
static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra);
|
||||
|
||||
template <typename T>
|
||||
static T json_value(const json &body, const std::string &key, const T &default_value) {
|
||||
static T json_value(const json & body, const std::string & key, const T & default_value) {
|
||||
// Fallback null to default value
|
||||
if (body.contains(key) && !body.at(key).is_null()){
|
||||
if (body.contains(key) && !body.at(key).is_null()) {
|
||||
try {
|
||||
return body.value(key, default_value);
|
||||
}
|
||||
catch (nlohmann::json_abi_v3_11_3::detail::type_error const&){
|
||||
std::string message = "Wrong type supplied for parameter '" + key + "'. Expected '" + typeid(default_value).name() + "', using default value.";
|
||||
server_log("WARN", __func__, __LINE__, message.c_str(), body);
|
||||
return body.at(key);
|
||||
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
|
||||
std::stringstream ss;
|
||||
ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value.";
|
||||
LOG_WARNING(ss.str().c_str(), body);
|
||||
return default_value;
|
||||
}
|
||||
} else {
|
||||
|
@ -68,16 +68,16 @@ static T json_value(const json &body, const std::string &key, const T &default_v
|
|||
}
|
||||
}
|
||||
|
||||
static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) {
|
||||
static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra) {
|
||||
std::stringstream ss_tid;
|
||||
ss_tid << std::this_thread::get_id();
|
||||
json log = nlohmann::ordered_json{
|
||||
json log = json{
|
||||
{"tid", ss_tid.str()},
|
||||
{"timestamp", time(nullptr)},
|
||||
};
|
||||
|
||||
if (server_log_json) {
|
||||
log.merge_patch( {
|
||||
log.merge_patch({
|
||||
{"level", level},
|
||||
{"function", function},
|
||||
{"line", line},
|
||||
|
@ -98,7 +98,7 @@ static inline void server_log(const char *level, const char *function, int line,
|
|||
}
|
||||
std::stringstream ss;
|
||||
ss << buf << " |";
|
||||
for (const auto& el : log.items())
|
||||
for (const auto & el : log.items())
|
||||
{
|
||||
const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
ss << " " << el.key() << "=" << value;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# llama.cpp/example/sycl
|
||||
|
||||
This example program provide the tools for llama.cpp for SYCL on Intel GPU.
|
||||
This example program provides the tools for llama.cpp for SYCL on Intel GPU.
|
||||
|
||||
## Tool
|
||||
|
||||
|
|
30
flake.lock
generated
30
flake.lock
generated
|
@ -5,11 +5,11 @@
|
|||
"nixpkgs-lib": "nixpkgs-lib"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1712014858,
|
||||
"narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=",
|
||||
"lastModified": 1714641030,
|
||||
"narHash": "sha256-yzcRNDoyVP7+SCNX0wmuDju1NUCt8Dz9+lyUXEI0dbI=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "9126214d0a59633752a136528f5f3b9aa8565b7d",
|
||||
"rev": "e5d10a24b66c3ea8f150e47dfdb0416ab7c3390e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
@ -20,11 +20,11 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1714076141,
|
||||
"narHash": "sha256-Drmja/f5MRHZCskS6mvzFqxEaZMeciScCTFxWVLqWEY=",
|
||||
"lastModified": 1714635257,
|
||||
"narHash": "sha256-4cPymbty65RvF1DWQfc+Bc8B233A1BWxJnNULJKQ1EY=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "7bb2ccd8cdc44c91edba16c48d2c8f331fb3d856",
|
||||
"rev": "63c3a29ca82437c87573e4c6919b09a24ea61b0f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
@ -36,20 +36,14 @@
|
|||
},
|
||||
"nixpkgs-lib": {
|
||||
"locked": {
|
||||
"dir": "lib",
|
||||
"lastModified": 1711703276,
|
||||
"narHash": "sha256-iMUFArF0WCatKK6RzfUJknjem0H9m4KgorO/p3Dopkk=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "d8fe5e6c92d0d190646fb9f1056741a229980089",
|
||||
"type": "github"
|
||||
"lastModified": 1714640452,
|
||||
"narHash": "sha256-QBx10+k6JWz6u7VsohfSw8g8hjdBZEf8CFzXH1/1Z94=",
|
||||
"type": "tarball",
|
||||
"url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz"
|
||||
},
|
||||
"original": {
|
||||
"dir": "lib",
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
"type": "tarball",
|
||||
"url": "https://github.com/NixOS/nixpkgs/archive/50eb7ecf4cd0a5756d7275c8ba36790e5bd53e33.tar.gz"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
|
|
|
@ -113,7 +113,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
CUdevice device;
|
||||
CU_CHECK(cuDeviceGet(&device, id));
|
||||
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
||||
|
@ -259,7 +259,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
|||
};
|
||||
|
||||
// pool with virtual memory
|
||||
#if !defined(GGML_USE_HIPBLAS)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
||||
|
||||
|
@ -356,7 +356,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
|||
#endif // !defined(GGML_USE_HIPBLAS)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
#if !defined(GGML_USE_HIPBLAS)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
}
|
||||
|
|
77
ggml-impl.h
77
ggml-impl.h
|
@ -17,6 +17,83 @@
|
|||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
/**
|
||||
* Converts brain16 to float32.
|
||||
*
|
||||
* The bfloat16 floating point format has the following structure:
|
||||
*
|
||||
* ┌sign
|
||||
* │
|
||||
* │ ┌exponent
|
||||
* │ │
|
||||
* │ │ ┌mantissa
|
||||
* │ │ │
|
||||
* │┌──┴───┐┌─┴───┐
|
||||
* 0b0000000000000000 brain16
|
||||
*
|
||||
* Since bf16 has the same number of exponent bits as a 32bit float,
|
||||
* encoding and decoding numbers becomes relatively straightforward.
|
||||
*
|
||||
* ┌sign
|
||||
* │
|
||||
* │ ┌exponent
|
||||
* │ │
|
||||
* │ │ ┌mantissa
|
||||
* │ │ │
|
||||
* │┌──┴───┐┌─┴───────────────────┐
|
||||
* 0b00000000000000000000000000000000 IEEE binary32
|
||||
*
|
||||
* For comparison, the standard fp16 format has fewer exponent bits.
|
||||
*
|
||||
* ┌sign
|
||||
* │
|
||||
* │ ┌exponent
|
||||
* │ │
|
||||
* │ │ ┌mantissa
|
||||
* │ │ │
|
||||
* │┌─┴─┐┌─┴──────┐
|
||||
* 0b0000000000000000 IEEE binary16
|
||||
*
|
||||
* @see IEEE 754-2008
|
||||
*/
|
||||
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
|
||||
union {
|
||||
float f;
|
||||
uint32_t i;
|
||||
} u;
|
||||
u.i = (uint32_t)h.bits << 16;
|
||||
return u.f;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts float32 to brain16.
|
||||
*
|
||||
* This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
|
||||
* Subnormals shall be flushed to zero, and NANs will be quiet.
|
||||
* This code should vectorize nicely if using modern compilers.
|
||||
*/
|
||||
static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
|
||||
ggml_bf16_t h;
|
||||
union {
|
||||
float f;
|
||||
uint32_t i;
|
||||
} u;
|
||||
u.f = s;
|
||||
if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
|
||||
h.bits = (u.i >> 16) | 64; /* force to quiet */
|
||||
return h;
|
||||
}
|
||||
if (!(u.i & 0x7f800000)) { /* subnormal */
|
||||
h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
|
||||
return h;
|
||||
}
|
||||
h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
|
||||
return h;
|
||||
}
|
||||
|
||||
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
|
||||
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -803,7 +803,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[3] == 1;
|
||||
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
|
|
|
@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
|
||||
const short D4 = D/4;
|
||||
const short D8 = D/8;
|
||||
const short Q8 = Q/8;
|
||||
//const short Q8 = Q/8;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
|
|
|
@ -12450,6 +12450,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
|||
const size_t nb = nbytes/ggml_type_size(type);
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
int nans = 0;
|
||||
int infs = 0;
|
||||
const unsigned short * f = (const unsigned short *) data;
|
||||
for (size_t i = 0; i < nb; ++i) {
|
||||
nans += (f[i] & 0x7fff) > 0x7f80;
|
||||
infs += (f[i] & 0x7fff) == 0x7f80;
|
||||
}
|
||||
if (nans) {
|
||||
fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
|
||||
return false;
|
||||
}
|
||||
if (infs) {
|
||||
fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
const ggml_fp16_t * f = (const ggml_fp16_t *) data;
|
||||
|
|
20
ggml.h
20
ggml.h
|
@ -326,14 +326,20 @@ extern "C" {
|
|||
// get ggml_status name string
|
||||
GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
|
||||
|
||||
// ieee 754-2008 half-precision float16
|
||||
// todo: make this not an integral type
|
||||
typedef uint16_t ggml_fp16_t;
|
||||
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t);
|
||||
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
|
||||
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
|
||||
GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
|
||||
|
||||
// convert FP16 <-> FP32
|
||||
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
|
||||
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
|
||||
|
||||
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n);
|
||||
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n);
|
||||
// google brain half-precision bfloat16
|
||||
typedef struct { uint16_t bits; } ggml_bf16_t;
|
||||
GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
|
||||
GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16
|
||||
GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
|
||||
GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
|
||||
|
||||
struct ggml_object;
|
||||
struct ggml_context;
|
||||
|
@ -370,6 +376,7 @@ extern "C" {
|
|||
GGML_TYPE_I64 = 27,
|
||||
GGML_TYPE_F64 = 28,
|
||||
GGML_TYPE_IQ1_M = 29,
|
||||
GGML_TYPE_BF16 = 30,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
|
@ -410,6 +417,7 @@ extern "C" {
|
|||
GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
|
||||
};
|
||||
|
||||
// available tensor operations:
|
||||
|
|
|
@ -817,6 +817,7 @@ class GGMLQuantizationType(IntEnum):
|
|||
I64 = 27
|
||||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
|
||||
|
||||
class GGUFEndian(IntEnum):
|
||||
|
@ -888,6 +889,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
|||
GGMLQuantizationType.I64: (1, 8),
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ single-line ::= [^\n]+ "\n"`
|
|||
|
||||
## Sequences and Alternatives
|
||||
|
||||
The order of symbols in a sequence matter. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc.
|
||||
The order of symbols in a sequence matters. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc.
|
||||
|
||||
Alternatives, denoted by `|`, give different sequences that are acceptable. For example, in `move ::= pawn | nonpawn | castle`, `move` can be a `pawn` move, a `nonpawn` move, or a `castle`.
|
||||
|
||||
|
|
38
llama.cpp
38
llama.cpp
|
@ -3175,6 +3175,7 @@ struct llama_model_loader {
|
|||
switch (type_max) {
|
||||
case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break;
|
||||
case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break;
|
||||
case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
|
||||
case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
|
||||
case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
|
||||
case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
|
||||
|
@ -3666,6 +3667,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
|||
switch (ftype) {
|
||||
case LLAMA_FTYPE_ALL_F32: return "all F32";
|
||||
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
|
||||
|
@ -4389,6 +4391,15 @@ static void llm_load_vocab(
|
|||
} else if (
|
||||
tokenizer_pre == "command-r") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
|
||||
} else if (
|
||||
tokenizer_pre == "qwen2") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
||||
} else if (
|
||||
tokenizer_pre == "olmo") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO;
|
||||
} else if (
|
||||
tokenizer_pre == "dbrx") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
|
||||
|
@ -6126,6 +6137,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
|||
|| !(
|
||||
model.ftype == LLAMA_FTYPE_ALL_F32 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
|
||||
model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
|
||||
)
|
||||
|
@ -12194,6 +12206,7 @@ struct llm_tokenizer_bpe {
|
|||
case LLAMA_VOCAB_TYPE_BPE:
|
||||
switch (vocab.type_pre) {
|
||||
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
|
||||
case LLAMA_VOCAB_PRE_TYPE_DBRX:
|
||||
word_collection = unicode_regex_split(text, {
|
||||
// original regex from tokenizer.json
|
||||
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
|
@ -12248,10 +12261,18 @@ struct llm_tokenizer_bpe {
|
|||
});
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
||||
word_collection = unicode_regex_split(text, {
|
||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||
});
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
||||
word_collection = unicode_regex_split(text, {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
});
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
word_collection = unicode_regex_split(text, {
|
||||
|
@ -14154,13 +14175,16 @@ static void llama_tensor_dequantize_internal(
|
|||
if (qtype.to_float == NULL) {
|
||||
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
|
||||
}
|
||||
} else if (tensor->type != GGML_TYPE_F16) {
|
||||
} else if (tensor->type != GGML_TYPE_F16 &&
|
||||
tensor->type != GGML_TYPE_BF16) {
|
||||
throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
|
||||
}
|
||||
|
||||
if (nthread < 2) {
|
||||
if (tensor->type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
|
||||
} else if (tensor->type == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
|
||||
} else if (ggml_is_quantized(tensor->type)) {
|
||||
qtype.to_float(tensor->data, f32_output, nelements);
|
||||
} else {
|
||||
|
@ -14169,7 +14193,14 @@ static void llama_tensor_dequantize_internal(
|
|||
return;
|
||||
}
|
||||
|
||||
size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
|
||||
size_t block_size;
|
||||
if (tensor->type == GGML_TYPE_F16 ||
|
||||
tensor->type == GGML_TYPE_BF16) {
|
||||
block_size = 1;
|
||||
} else {
|
||||
block_size = (size_t)ggml_blck_size(tensor->type);
|
||||
}
|
||||
|
||||
size_t block_size_bytes = ggml_type_size(tensor->type);
|
||||
|
||||
GGML_ASSERT(nelements % block_size == 0);
|
||||
|
@ -14188,6 +14219,8 @@ static void llama_tensor_dequantize_internal(
|
|||
auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
|
||||
if (typ == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
|
||||
} else if (typ == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
|
||||
} else {
|
||||
qtype.to_float(inbuf, outbuf, nels);
|
||||
}
|
||||
|
@ -14548,6 +14581,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
|
||||
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
|
||||
|
||||
// K-quants
|
||||
|
|
4
llama.h
4
llama.h
|
@ -81,6 +81,9 @@ extern "C" {
|
|||
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
||||
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
||||
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
||||
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
|
||||
LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
|
||||
LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
|
||||
};
|
||||
|
||||
// note: these values should be synchronized with ggml_rope
|
||||
|
@ -136,6 +139,7 @@ extern "C" {
|
|||
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
||||
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
|
BIN
models/ggml-vocab-qwen2.gguf
Normal file
BIN
models/ggml-vocab-qwen2.gguf
Normal file
Binary file not shown.
106
models/ggml-vocab-qwen2.gguf.inp
Normal file
106
models/ggml-vocab-qwen2.gguf.inp
Normal file
|
@ -0,0 +1,106 @@
|
|||
ied 4 ½ months
|
||||
__ggml_vocab_test__
|
||||
Führer
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
this is 🦙.cpp
|
||||
__ggml_vocab_test__
|
||||
w048 7tuijk dsdfhu
|
||||
__ggml_vocab_test__
|
||||
нещо на Български
|
||||
__ggml_vocab_test__
|
||||
កាន់តែពិសេសអាចខលចេញ
|
||||
__ggml_vocab_test__
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
(
|
||||
__ggml_vocab_test__
|
||||
|
||||
=
|
||||
__ggml_vocab_test__
|
||||
' era
|
||||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
__ggml_vocab_test__
|
||||
333
|
||||
__ggml_vocab_test__
|
||||
3333
|
||||
__ggml_vocab_test__
|
||||
33333
|
||||
__ggml_vocab_test__
|
||||
333333
|
||||
__ggml_vocab_test__
|
||||
3333333
|
||||
__ggml_vocab_test__
|
||||
33333333
|
||||
__ggml_vocab_test__
|
||||
333333333
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
|
||||
__ggml_vocab_test__
|
43
models/ggml-vocab-qwen2.gguf.out
Normal file
43
models/ggml-vocab-qwen2.gguf.out
Normal file
|
@ -0,0 +1,43 @@
|
|||
1122 220 19 220 26062 3951
|
||||
37 50753 261
|
||||
|
||||
220
|
||||
256
|
||||
262
|
||||
197
|
||||
198
|
||||
271
|
||||
1406
|
||||
1572
|
||||
9707 1879
|
||||
21927 1879
|
||||
9707 4337
|
||||
21927 4337
|
||||
21927 4337 0
|
||||
9707 11 1879 0
|
||||
21927 11 1879 0
|
||||
419 374 11162 99 247 13 10821
|
||||
86 15 19 23 220 22 83 1963 41808 11472 2940 16739
|
||||
78762 14144 1456 13073 63471 33594 3038 133178 79012
|
||||
146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848
|
||||
145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8
|
||||
9707
|
||||
21927
|
||||
220 21927
|
||||
256 21927
|
||||
262 21927
|
||||
262 21927 198 262 21927
|
||||
320
|
||||
198 284
|
||||
6 11385
|
||||
9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216
|
||||
18
|
||||
18 18
|
||||
18 18 18
|
||||
18 18 18 18
|
||||
18 18 18 18 18
|
||||
18 18 18 18 18 18
|
||||
18 18 18 18 18 18 18
|
||||
18 18 18 18 18 18 18 18
|
||||
18 18 18 18 18 18 18 18 18
|
||||
198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43
|
|
@ -93,11 +93,14 @@ help_s = (
|
|||
"specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench."
|
||||
)
|
||||
parser.add_argument("-s", "--show", help=help_s)
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
|
||||
known_args, unknown_args = parser.parse_known_args()
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
|
||||
|
||||
if unknown_args:
|
||||
logger.error(f"Received unknown args: {unknown_args}.")
|
||||
logger.error(f"Received unknown args: {unknown_args}.\n")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
@ -110,7 +113,7 @@ if input_file is None:
|
|||
input_file = sqlite_files[0]
|
||||
|
||||
if input_file is None:
|
||||
logger.error("Cannot find a suitable input file, please provide one.")
|
||||
logger.error("Cannot find a suitable input file, please provide one.\n")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
@ -202,12 +205,12 @@ elif repo is not None:
|
|||
hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)
|
||||
|
||||
if hexsha8_baseline is None:
|
||||
logger.error("No baseline was provided and did not find data for any master branch commits.")
|
||||
logger.error("No baseline was provided and did not find data for any master branch commits.\n")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error("No baseline was provided and the current working directory "
|
||||
"is not part of a git repository from which a baseline could be inferred.")
|
||||
"is not part of a git repository from which a baseline could be inferred.\n")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
@ -238,7 +241,7 @@ elif repo is not None:
|
|||
break
|
||||
|
||||
if hexsha8_compare is None:
|
||||
logger.error("No compare target was provided and did not find data for any non-master commits.")
|
||||
logger.error("No compare target was provided and did not find data for any non-master commits.\n")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
else:
|
||||
|
@ -361,7 +364,7 @@ if "gpu_info" in show:
|
|||
headers = [PRETTY_NAMES[p] for p in show]
|
||||
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
|
||||
|
||||
logger.info(tabulate(
|
||||
print(tabulate( # noqa: NP100
|
||||
table,
|
||||
headers=headers,
|
||||
floatfmt=".2f",
|
||||
|
|
77
sgemm.cpp
77
sgemm.cpp
|
@ -1,6 +1,3 @@
|
|||
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
|
||||
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
|
||||
//
|
||||
// Copyright 2024 Mozilla Foundation
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining
|
||||
|
@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM {
|
|||
};
|
||||
#endif // __ARM_FEATURE_DOTPROD
|
||||
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
||||
template <typename TA, typename TB, typename TC>
|
||||
class tinyBLAS_Q0_AVX2 {
|
||||
class tinyBLAS_Q0_AVX {
|
||||
public:
|
||||
tinyBLAS_Q0_AVX2(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const TB *B, int64_t ldb,
|
||||
TC *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
tinyBLAS_Q0_AVX(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const TB *B, int64_t ldb,
|
||||
TC *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
}
|
||||
|
||||
|
@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 {
|
|||
__m256 Cv[RN][RM] = {};
|
||||
for (int64_t l = 0; l < k; ++l)
|
||||
for (int64_t j = 0; j < RN; ++j)
|
||||
for (int64_t i = 0; i < RM; ++i)
|
||||
for (int64_t i = 0; i < RM; ++i) {
|
||||
#if defined(__AVX2__)
|
||||
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
||||
load(A + lda * (ii + i) + l)),
|
||||
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
|
||||
load(A + lda * (ii + i) + l)));
|
||||
#else
|
||||
__m128i ali0 = load0(A + lda * (ii + i) + l);
|
||||
__m128i ali1 = load1(A + lda * (ii + i) + l);
|
||||
__m128i blj0 = load0(B + ldb * (jj + j) + l);
|
||||
__m128i blj1 = load1(B + ldb * (jj + j) + l);
|
||||
|
||||
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
|
||||
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
|
||||
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
|
||||
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
|
||||
|
||||
// updot
|
||||
const __m128i oneFill = _mm_set1_epi16(1);
|
||||
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
|
||||
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
|
||||
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
|
||||
#endif
|
||||
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
|
||||
unhalf(B[ldb * (jj + j) + l].d)),
|
||||
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
||||
load(A + lda * (ii + i) + l)),
|
||||
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
|
||||
load(A + lda * (ii + i) + l))),
|
||||
Cv[j][i]);
|
||||
udTmp,
|
||||
Cv[j][i]);
|
||||
}
|
||||
for (int64_t j = 0; j < RN; ++j)
|
||||
for (int64_t i = 0; i < RM; ++i)
|
||||
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
||||
|
@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
|
|||
return _mm256_loadu_si256((const __m256i *)b->qs);
|
||||
}
|
||||
|
||||
inline __m128i load0(const block_q8_0 *b) {
|
||||
return _mm_loadu_si128((const __m128i *)b->qs);
|
||||
}
|
||||
|
||||
inline __m128i load1(const block_q8_0 *b) {
|
||||
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
|
||||
}
|
||||
|
||||
inline __m256i load(const block_q4_0 *b) {
|
||||
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
|
||||
}
|
||||
|
||||
inline __m128i load0(const block_q4_0 *b) {
|
||||
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
||||
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
|
||||
}
|
||||
|
||||
inline __m128i load1(const block_q4_0 *b) {
|
||||
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
||||
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
|
||||
}
|
||||
|
||||
inline __m256 updot(__m256i u, __m256i s) {
|
||||
__m256i res;
|
||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
||||
|
@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 {
|
|||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
#endif // __AVX2__
|
||||
#endif // __AVX__
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
case GGML_TYPE_Q8_0: {
|
||||
if (Btype != GGML_TYPE_Q8_0)
|
||||
return false;
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
|
||||
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
||||
tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
|
||||
k, (const block_q8_0 *)A, lda,
|
||||
(const block_q8_0 *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
|
@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
case GGML_TYPE_Q4_0: {
|
||||
if (Btype != GGML_TYPE_Q8_0)
|
||||
return false;
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
|
||||
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
||||
tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
|
||||
k, (const block_q4_0 *)A, lda,
|
||||
(const block_q8_0 *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
|
|
|
@ -84,6 +84,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE
|
|||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-qwen2.gguf)
|
||||
|
||||
# build test-tokenizer-1-bpe target once and add many tests
|
||||
add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
|
||||
|
|
|
@ -50,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
|||
|
||||
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
|
||||
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
|
||||
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
|
||||
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
|
||||
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
|
||||
std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
|
||||
std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
|
||||
|
@ -92,6 +92,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
|
|||
size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
|
||||
if (t->type == GGML_TYPE_F16) {
|
||||
tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
|
||||
} else if (t->type == GGML_TYPE_BF16) {
|
||||
tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
|
||||
} else if (t->type == GGML_TYPE_F32) {
|
||||
tv.push_back(*(float *) &buf[i]);
|
||||
} else if (t->type == GGML_TYPE_I32) {
|
||||
|
@ -1898,7 +1900,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
std::default_random_engine rng(0);
|
||||
|
||||
const ggml_type all_types[] = {
|
||||
GGML_TYPE_F32, GGML_TYPE_F16,
|
||||
GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue