Merge branch 'master' into logging_callback

# Conflicts:
#	llama.cpp
#	llama.h
grahameth 2023-07-23 22:14:05 +02:00
commit 152b633691
15 changed files with 740 additions and 827 deletions

Makefile

@@ -235,13 +235,15 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 endif # LLAMA_CUBLAS

 ifdef LLAMA_CLBLAST
-	CFLAGS   += -DGGML_USE_CLBLAST
-	CXXFLAGS += -DGGML_USE_CLBLAST
+	CFLAGS   += -DGGML_USE_CLBLAST $(shell pkg-config --cflags clblast OpenCL)
+	CXXFLAGS += -DGGML_USE_CLBLAST $(shell pkg-config --cflags clblast OpenCL)

 	# Mac provides OpenCL as a framework
 	ifeq ($(UNAME_S),Darwin)
 		LDFLAGS += -lclblast -framework OpenCL
 	else
-		LDFLAGS += -lclblast -lOpenCL
+		LDFLAGS += $(shell pkg-config --libs clblast OpenCL)
 	endif
 	OBJS    += ggml-opencl.o

README.md

@@ -242,6 +242,23 @@ In order to build llama.cpp you have three different options.
     zig build -Doptimize=ReleaseFast
     ```
+
+- Using `gmake` (FreeBSD):
+
+    1. Install and activate [DRM in FreeBSD](https://wiki.freebsd.org/Graphics)
+    2. Add your user to the **video** group
+    3. Install compilation dependencies.
+
+        ```bash
+        sudo pkg install gmake automake autoconf pkgconf llvm15 clinfo clover \
+            opencl clblast openblas
+
+        gmake CC=/usr/local/bin/clang15 CXX=/usr/local/bin/clang++15 -j4
+        ```
+
+    **Notes:** With these packages you can build llama.cpp with OpenBLAS and
+    CLBlast support, enabling OpenCL GPU acceleration on FreeBSD. Please read
+    the instructions below on how to enable and use these options.

 ### Metal Build

 Using Metal allows the computation to be executed on the GPU for Apple devices:
@@ -384,7 +401,7 @@ Building the program with BLAS support may lead to some performance improvements
 | Option | Legal values | Default | Description |
 |-------------------------|------------------------|---------|-------------|
-| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 7.0/Turing/RTX 2000 or higher). Does not affect k-quants. |
+| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
 | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
 | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
 | LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
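
The LLAMA_CUDA_FORCE_DMMV row above is the user-facing side of the compute-capability check this commit relaxes from 7.0 to 6.1 (the first architecture with the dp4a integer intrinsic). A minimal C++ sketch of that decision, assuming the MIN_CC_DP4A threshold of 610 used in ggml-cuda.cu; the helper name is illustrative, not the library's API:

```cpp
#include <cstdio>

// 6.1/Pascal: assumed lowest compute capability with dp4a integer intrinsics
static const int MIN_CC_DP4A = 610;

// decide between the quantized kernels (MMVQ) and dequantize + mat-vec (DMMV)
static bool use_mul_mat_vec_q(int compute_capability, bool force_dmmv) {
    if (force_dmmv) {
        return false; // LLAMA_CUDA_FORCE_DMMV overrides the capability check
    }
    return compute_capability >= MIN_CC_DP4A;
}

int main() {
    printf("GTX 1080 (cc 6.1): MMVQ=%d\n", use_mul_mat_vec_q(610, false)); // 1
    printf("GTX 980  (cc 5.2): MMVQ=%d\n", use_mul_mat_vec_q(520, false)); // 0
}
```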

convert.py

@@ -194,13 +194,38 @@ class Params:
             n_layer = n_layer,
         )

+    # LLaMA v2 70B params.json
+    # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
+    @staticmethod
+    def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+        config = json.load(open(config_path))
+
+        n_vocab = config["vocab_size"]
+        n_embd  = config["dim"]
+        n_head  = config["n_heads"]
+        n_layer = config["n_layers"]
+        n_mult  = config["multiple_of"]
+
+        if n_vocab == -1:
+            n_vocab = model["tok_embeddings.weight"].shape[0]
+
+        return Params(
+            n_vocab = n_vocab,
+            n_embd  = n_embd,
+            n_mult  = n_mult,
+            n_head  = n_head,
+            n_layer = n_layer,
+        )
     @staticmethod
     def load(model_plus: 'ModelPlus') -> 'Params':
+        hf_config_path   = model_plus.paths[0].parent / "config.json"
         orig_config_path = model_plus.paths[0].parent / "params.json"
-        hf_transformer_config_path = model_plus.paths[0].parent / "config.json"

-        if hf_transformer_config_path.exists():
-            params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+        if hf_config_path.exists():
+            params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
+        elif orig_config_path.exists():
+            params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
         else:
             params = Params.guessed(model_plus.model)
@@ -1036,8 +1061,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
-                        n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)

examples/common.cpp

@@ -117,6 +117,9 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_threads = std::stoi(argv[i]);
+            if (params.n_threads <= 0) {
+                params.n_threads = std::thread::hardware_concurrency();
+            }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -168,6 +171,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_gqa = std::stoi(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
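
For context on the new -gqa flag: grouped-query attention shares each key/value head across n_gqa query heads, so for LLaMA v2 70B (n_heads = 64, n_kv_heads = 8 in the params.json quoted in convert.py above) the factor is 64/8 = 8, matching the "use 8 for LLaMAv2 70B" hint in the usage text below. An illustrative C++ sketch of the relationship, not code from this commit:

```cpp
#include <cassert>
#include <cstdio>

struct attn_dims {
    int n_head; // query heads, e.g. 64 for LLaMA v2 70B
    int n_gqa;  // grouped-query attention factor, e.g. 8 (from --gqa 8)

    // number of KV heads implied by the factor; matches n_kv_heads in params.json
    int n_head_kv() const {
        assert(n_head % n_gqa == 0);
        return n_head / n_gqa;
    }
};

int main() {
    const attn_dims d = {64, 8};
    printf("n_head_kv = %d\n", d.n_head_kv()); // 8
}
```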
@@ -458,91 +467,92 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 }

 void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    fprintf(stderr, "usage: %s [options]\n", argv[0]);
-    fprintf(stderr, "\n");
-    fprintf(stderr, "options:\n");
-    fprintf(stderr, " -h, --help show this help message and exit\n");
-    fprintf(stderr, " -i, --interactive run in interactive mode\n");
-    fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
-    fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
-    fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
-    fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
-    fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n");
-    fprintf(stderr, " (can be specified more than once for multiple prompts).\n");
-    fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
-    fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
-    fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
-    fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
-    fprintf(stderr, " prompt to start generation with (default: empty)\n");
-    fprintf(stderr, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
-    fprintf(stderr, " --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n");
-    fprintf(stderr, " --prompt-cache-all if specified, saves user input and generations to cache as well.\n");
-    fprintf(stderr, " not supported with --interactive or other interactive options\n");
-    fprintf(stderr, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n");
-    fprintf(stderr, " --random-prompt start with a randomized prompt.\n");
-    fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
-    fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
-    fprintf(stderr, " -f FNAME, --file FNAME\n");
-    fprintf(stderr, " prompt file to start generation.\n");
-    fprintf(stderr, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
-    fprintf(stderr, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
-    fprintf(stderr, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
-    fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
-    fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
-    fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
-    fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
-    fprintf(stderr, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
-    fprintf(stderr, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
-    fprintf(stderr, " --mirostat N use Mirostat sampling.\n");
-    fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
-    fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
-    fprintf(stderr, " --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
-    fprintf(stderr, " --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
-    fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
-    fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
-    fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
-    fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
-    fprintf(stderr, " --cfg-negative-prompt PROMPT \n");
-    fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n");
-    fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
-    fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
-    fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
-    fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
-    fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
-    fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
-    fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
-    fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
-    fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
-    fprintf(stderr, " --perplexity compute perplexity over each ctx window of the prompt\n");
-    fprintf(stderr, " --perplexity-lines compute perplexity over each line of the prompt\n");
-    fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
-    fprintf(stderr, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
+    fprintf(stdout, "usage: %s [options]\n", argv[0]);
+    fprintf(stdout, "\n");
+    fprintf(stdout, "options:\n");
+    fprintf(stdout, " -h, --help show this help message and exit\n");
+    fprintf(stdout, " -i, --interactive run in interactive mode\n");
+    fprintf(stdout, " --interactive-first run in interactive mode and wait for input right away\n");
+    fprintf(stdout, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
+    fprintf(stdout, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
+    fprintf(stdout, " -r PROMPT, --reverse-prompt PROMPT\n");
+    fprintf(stdout, " halt generation at PROMPT, return control in interactive mode\n");
+    fprintf(stdout, " (can be specified more than once for multiple prompts).\n");
+    fprintf(stdout, " --color colorise output to distinguish prompt and user input from generations\n");
+    fprintf(stdout, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
+    fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stdout, " -p PROMPT, --prompt PROMPT\n");
+    fprintf(stdout, " prompt to start generation with (default: empty)\n");
+    fprintf(stdout, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
+    fprintf(stdout, " --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n");
+    fprintf(stdout, " --prompt-cache-all if specified, saves user input and generations to cache as well.\n");
+    fprintf(stdout, " not supported with --interactive or other interactive options\n");
+    fprintf(stdout, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n");
+    fprintf(stdout, " --random-prompt start with a randomized prompt.\n");
+    fprintf(stdout, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
+    fprintf(stdout, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
+    fprintf(stdout, " -f FNAME, --file FNAME\n");
+    fprintf(stdout, " prompt file to start generation.\n");
+    fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+    fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+    fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+    fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+    fprintf(stdout, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+    fprintf(stdout, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+    fprintf(stdout, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+    fprintf(stdout, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+    fprintf(stdout, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+    fprintf(stdout, " --mirostat N use Mirostat sampling.\n");
+    fprintf(stdout, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+    fprintf(stdout, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+    fprintf(stdout, " --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+    fprintf(stdout, " --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+    fprintf(stdout, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+    fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n");
+    fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+    fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+    fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
+    fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
+    fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
+    fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
+    fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+    fprintf(stdout, " --no-penalize-nl do not penalize newline token\n");
+    fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
+    fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n");
+    fprintf(stdout, " --temp N temperature (default: %.1f)\n", (double)params.temp);
+    fprintf(stdout, " --perplexity compute perplexity over each ctx window of the prompt\n");
+    fprintf(stdout, " --perplexity-lines compute perplexity over each line of the prompt\n");
+    fprintf(stdout, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
-        fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
+        fprintf(stdout, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
     }
     if (llama_mmap_supported()) {
-        fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
+        fprintf(stdout, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
-    fprintf(stderr, " --numa attempt optimizations that help on some NUMA systems\n");
-    fprintf(stderr, " if run without this previously, it is recommended to drop the system page cache before using this\n");
-    fprintf(stderr, " see https://github.com/ggerganov/llama.cpp/issues/1437\n");
+    fprintf(stdout, " --numa attempt optimizations that help on some NUMA systems\n");
+    fprintf(stdout, " if run without this previously, it is recommended to drop the system page cache before using this\n");
+    fprintf(stdout, " see https://github.com/ggerganov/llama.cpp/issues/1437\n");
 #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
-    fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
-    fprintf(stderr, " number of layers to store in VRAM\n");
-    fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n");
-    fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
-    fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
-    fprintf(stderr, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
+    fprintf(stdout, " -ngl N, --n-gpu-layers N\n");
+    fprintf(stdout, " number of layers to store in VRAM\n");
+    fprintf(stdout, " -ts SPLIT --tensor-split SPLIT\n");
+    fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
+    fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
+    fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
 #endif
-    fprintf(stderr, " --mtest compute maximum memory usage\n");
-    fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
-    fprintf(stderr, " --verbose-prompt print prompt before generation\n");
-    fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
-    fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
-    fprintf(stderr, " -m FNAME, --model FNAME\n");
-    fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
-    fprintf(stderr, "\n");
+    fprintf(stdout, " --mtest compute maximum memory usage\n");
+    fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n");
+    fprintf(stdout, " --verbose-prompt print prompt before generation\n");
+    fprintf(stdout, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
+    fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
+    fprintf(stdout, " -m FNAME, --model FNAME\n");
+    fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "\n");
 }

 std::string gpt_random_prompt(std::mt19937 & rng) {
@@ -580,6 +590,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params
     lparams.n_ctx        = params.n_ctx;
     lparams.n_batch      = params.n_batch;
+    lparams.n_gqa        = params.n_gqa;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu     = params.main_gpu;
     lparams.tensor_split = params.tensor_split;

examples/common.h

@@ -27,6 +27,7 @@ struct gpt_params {
     int32_t n_predict = -1;  // new tokens to predict
     int32_t n_ctx = 512;     // context size
     int32_t n_batch = 512;   // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_gqa = 1;       // grouped-query attention factor (TODO: move to hparams)
     int32_t n_keep = 0;      // number of tokens to keep from initial prompt
     int32_t n_chunks = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0; // number of layers to store in VRAM
@@ -47,7 +48,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64;      // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float frequency_penalty = 0.00f; // 0.0 = disabled
     float presence_penalty = 0.00f;  // 0.0 = disabled
-    int mirostat = 0;                // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    int32_t mirostat = 0;            // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f;      // target entropy
     float mirostat_eta = 0.10f;      // learning rate

examples/llm.vim

@@ -2,57 +2,22 @@ function! Llm()
   let url = "http://127.0.0.1:8080/completion"

-  " Save the current cursor position
-  let save_cursor = getpos('.')
-
-  silent! %s/\n/\\n/g
-  silent! %s/\t/\\t/g
-  silent! %s/\\n$//

   " Get the content of the current buffer
   let buffer_content = join(getline(1, '$'), "\n")

-  " Replace true newlines with "\n"
-  let buffer_content = substitute(buffer_content, '\n', '\\n', 'g')
-
-  " Trim leading/trailing whitespace
-  let buffer_content = substitute(buffer_content, '^\s\+', '', '')
-  let buffer_content = substitute(buffer_content, '\s\+$', '', '')

   " Create the JSON payload
-  " can't escape backslash, \n gets replaced as \\n
-  let json_payload = '{"prompt":"' . escape(buffer_content, '"/') . '","temp":0.72,"top_k":100,"top_p":0.73,"repeat_penalty":1.100000023841858,"n_predict":10,"stream":false}'
-
-  let prompt_tmpfile = tempname()
-  let response_tmpfile = tempname()
-  call writefile([json_payload], prompt_tmpfile)
+  let json_payload = {"temp":0.72,"top_k":100,"top_p":0.73,"repeat_penalty":1.100000023841858,"n_predict":10,"stream": v:false}
+  let json_payload.prompt = buffer_content

   " Define the curl command
-  let curl_command = 'curl -k -s -X POST -H "Content-Type: application/json" -o ' . shellescape(response_tmpfile) . ' -d @' . shellescape(prompt_tmpfile) . ' ' . url
-  silent execute '!'.curl_command
-
-  let response = join(readfile(response_tmpfile), '')
-  let start_marker = '{"content":"'
-  let end_marker = '","generation_settings'
-  let content_start = stridx(response, start_marker) + len(start_marker)
-  let content_end = stridx(response, end_marker, content_start)
+  let curl_command = 'curl -k -s -X POST -H "Content-Type: application/json" -d @- ' . url
+  let response = system(curl_command, json_encode(json_payload))

   " Extract the content field from the response
-  let content = strpart(response, content_start, content_end - content_start)
+  let content = json_decode(response).content

   " Insert the content at the cursor position
   call setline(line('.'), getline('.') . content)
-
-  " Replace newline "\n" strings with actual newlines in the content
-  silent! %s/\\n/\r/g
-
-  " and tabs
-  silent! %s/\\t/\t/g
-
-  " and quote marks for C sources
-  silent! %s/\\"/\"/g
-
-  " Remove the temporary file
-  call delete(prompt_tmpfile)
-  call delete(response_tmpfile)
 endfunction

 command! Llm call Llm()

examples/main/main.cpp

@@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
     }

     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
-                " you are on your own\n", __func__, params.n_ctx);
+        // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
@@ -139,17 +139,14 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

-    // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+    // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters
     // uncomment the "used_mem" line in llama.cpp to see the results
     if (params.mem_test) {
         {
-            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
-            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
-        }
-
-        {
-            const std::vector<llama_token> tmp = { 0, };
-            llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+            fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
+
+            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+            llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
         }

         llama_print_timings(ctx);

flake.nix

@@ -7,7 +7,8 @@
     flake-utils.lib.eachDefaultSystem (system:
       let
         inherit (pkgs.stdenv) isAarch32 isAarch64 isDarwin;
-        osSpecific = with pkgs; [ openmpi ] ++
+        buildInputs = with pkgs; [ openmpi ];
+        osSpecific = with pkgs; buildInputs ++
         (
           if isAarch64 && isDarwin then
             with pkgs.darwin.apple_sdk_11_0.frameworks; [
@@ -29,18 +30,24 @@
         nativeBuildInputs = with pkgs; [ cmake pkgconfig ];
         llama-python =
           pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
-      in {
-        packages.default = pkgs.stdenv.mkDerivation {
-          name = "llama.cpp";
-          src = ./.;
-          postPatch = ''
+        postPatch = ''
           substituteInPlace ./ggml-metal.m \
             --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";"
           substituteInPlace ./*.py --replace '/usr/bin/env python' '${llama-python}/bin/python'
         '';
+        postInstall = ''
+          mv $out/bin/main $out/bin/llama
+          mv $out/bin/server $out/bin/llama-server
+        '';
+        cmakeFlags = [ "-DLLAMA_BUILD_SERVER=ON" "-DLLAMA_MPI=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ];
+      in {
+        packages.default = pkgs.stdenv.mkDerivation {
+          name = "llama.cpp";
+          src = ./.;
+          postPatch = postPatch;
           nativeBuildInputs = nativeBuildInputs;
           buildInputs = osSpecific;
-          cmakeFlags = [ "-DLLAMA_BUILD_SERVER=ON" "-DLLAMA_MPI=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ]
+          cmakeFlags = cmakeFlags
             ++ (if isAarch64 && isDarwin then [
               "-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
               "-DLLAMA_METAL=ON"
@@ -48,10 +55,19 @@
               "-DLLAMA_BLAS=ON"
               "-DLLAMA_BLAS_VENDOR=OpenBLAS"
             ]);
-          postInstall = ''
-            mv $out/bin/main $out/bin/llama
-            mv $out/bin/server $out/bin/llama-server
-          '';
+          postInstall = postInstall;
+          meta.mainProgram = "llama";
+        };
+        packages.opencl = pkgs.stdenv.mkDerivation {
+          name = "llama.cpp";
+          src = ./.;
+          postPatch = postPatch;
+          nativeBuildInputs = nativeBuildInputs;
+          buildInputs = with pkgs; buildInputs ++ [ clblast ];
+          cmakeFlags = cmakeFlags ++ [
+            "-DLLAMA_CLBLAST=ON"
+          ];
+          postInstall = postInstall;
           meta.mainProgram = "llama";
         };
         apps.llama-server = {

ggml-cuda.cu

@@ -220,7 +220,7 @@ typedef struct {
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");

 #define WARP_SIZE 32
-#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
@@ -935,12 +935,18 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;

+#if K_QUANTS_PER_ITERATION == 2
+    uint32_t q32[4];
+    const uint8_t * q4 = (const uint8_t *)q32;
+#else
+    uint16_t q16[4];
+    const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
     float tmp = 0; // partial sum for thread in warp

     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint8_t * q1 = x[i].qs + q_offset;
-        const uint8_t * q2 = q1 + 64;
         const float * y1 = yy + i*QK_K + y_offset;
         const float * y2 = y1 + 128;
@@ -953,14 +959,41 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

+#if K_QUANTS_PER_ITERATION == 2
+        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+        const uint32_t * q2 = q1 + 16;
+
+        q32[0] = q1[0] & 0x0f0f0f0f;
+        q32[1] = q1[0] & 0xf0f0f0f0;
+        q32[2] = q2[0] & 0x0f0f0f0f;
+        q32[3] = q2[0] & 0xf0f0f0f0;
+
         float4 s = {0.f, 0.f, 0.f, 0.f};
         float smin = 0;
-        for (int l = 0; l < n; ++l) {
-            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
-            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+        for (int l = 0; l < 4; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
+            s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
         }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+        const uint16_t * q2 = q1 + 32;
+
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[0] & 0xf0f0;
+        q16[2] = q2[0] & 0x0f0f;
+        q16[3] = q2[0] & 0xf0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 2; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif

     }
 #else
@@ -1521,7 +1554,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;

-    const int bq8_offset = QR4_K * (iqs / QI8_1);
+    const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6

     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
@@ -1531,11 +1564,20 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);

-    for (int i = 0; i < QR4_K; ++i) {
-        const int isc = bq8_offset + i;
+    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;

-        uint8_t sc, m;
-        get_scale_min_k4(isc, bq4_K->scales, sc, m);
+    for (int i = 0; i < QR4_K; ++i) {

         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
@@ -1543,8 +1585,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
         const int vi = (v >> (4*i)) & 0x0F0F0F0F;

-        sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q4_K with sum of q8_1 values
+        sumf_d += d8i * (__dp4a(vi, ui, 0) * sc[i]); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]); // multiply constant part of q4_K with sum of q8_1 values
     }

     return d*sumf_d - dmin*sumf_m;
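
Both sums above lean on __dp4a, CUDA's packed 8-bit dot product with accumulate. As a reference for what each call computes, here is a portable C++ emulation (illustrative only; on the device the hardware intrinsic is used):

```cpp
#include <cstdint>
#include <cstdio>

// emulate __dp4a(a, b, c): treat a and b as four packed int8 lanes,
// multiply lane-wise, and add the four products to c
static int dp4a(int a, int b, int c) {
    for (int k = 0; k < 4; ++k) {
        const int8_t ak = (int8_t)(a >> (8 * k));
        const int8_t bk = (int8_t)(b >> (8 * k));
        c += ak * bk;
    }
    return c;
}

int main() {
    // lanes {1,2,3,4} . {10,20,30,40} = 10 + 40 + 90 + 160 = 300
    printf("%d\n", dp4a(0x04030201, 0x281E140A, 0)); // 300
}
```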
@@ -1745,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 }

-static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
     const half * x = (const half *) vx;

     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / (nchannels_y / nchannels_x);

     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
@@ -1765,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
         }

         // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
         const float xi = __half2float(x[ix]);

         const int row_y = col_x;
@@ -1793,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x) {
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {

     const half * x = (const half *) vx;

     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / channel_x_divisor;

     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
@@ -1815,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
         }

-        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const float xi = __half2float(x[ix]);

         const int row_y = col_x;
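
The new channel_x index in both kernels is the broadcast that grouped-query attention needs: src1 (the activations) now carries more channels than src0 (the KV tensor), and each group of nchannels_y/nchannels_x consecutive output channels must read the same src0 channel. A plain C++ stand-in for the mapping, with the 70B head counts as assumed example values:

```cpp
#include <cstdio>

int main() {
    const int nchannels_x = 8;  // e.g. 8 KV heads in src0
    const int nchannels_y = 64; // e.g. 64 query heads in src1 (gqa factor 8)
    const int divisor = nchannels_y / nchannels_x; // channel_x_divisor in the kernel

    for (int channel = 0; channel < nchannels_y; channel += divisor) {
        const int channel_x = channel / divisor; // same formula as in the kernels above
        printf("channels %d..%d read src0 channel %d\n",
               channel, channel + divisor - 1, channel_x);
    }
}
```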
@@ -2324,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }

-static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
-    const dim3 block_nums(1, nrows_x, nchannels_x);
+static void ggml_mul_mat_p021_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
 }

 static void ggml_mul_mat_vec_nc_f16_f32_cuda(
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {

-    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
 }

 static void ggml_cpy_f32_f32_cuda(
@@ -2497,7 +2547,9 @@ static size_t g_scratch_offset = 0;
 static int g_device_count = -1;
 static int g_main_device = 0;
+#ifndef GGML_CUDA_FORCE_DMMV
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
+#endif
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};

 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -2520,7 +2572,9 @@ void ggml_init_cublas() {
         g_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
+#ifndef GGML_CUDA_FORCE_DMMV
         g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+#endif
     }
     for (int id = 0; id < g_device_count; ++id) {
         g_tensor_split[id] /= total_vram;
@@ -2688,6 +2742,7 @@ inline void ggml_cuda_op_mul(
     (void) dst;
     (void) src0_ddq_i;
     (void) i02;
+    (void) i1;
 }

 inline void ggml_cuda_op_gelu(
@@ -2815,8 +2870,8 @@ inline void ggml_cuda_op_mul_mat_vec(
 #endif

     if (use_mul_mat_vec_q) {
-        int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
-        padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
+        const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
+            ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
         size_t as;
         void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
         quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
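
The rewritten padded_row_size simply rounds ne00 up to the next multiple of MATRIX_ROW_PADDING (now 512), so the quantized q8_1 copy of src1 never reads past the padded end of a row. The same rule as a standalone helper, for illustration:

```cpp
#include <cstdint>
#include <cstdio>

constexpr int64_t MATRIX_ROW_PADDING = 512;

// round n up to the next multiple of MATRIX_ROW_PADDING (identity if aligned)
constexpr int64_t pad_row_size(int64_t n) {
    return n % MATRIX_ROW_PADDING == 0
        ? n
        : n - n % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
}

int main() {
    printf("%lld -> %lld\n", 4096LL, (long long) pad_row_size(4096)); // 4096
    printf("%lld -> %lld\n", 4097LL, (long long) pad_row_size(4097)); // 4608
}
```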
@@ -2983,15 +3038,15 @@ inline void ggml_cuda_op_rope(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;

-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];

     // RoPE alteration for extended context
     float freq_base, freq_scale;
-    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
     const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
@@ -3007,6 +3062,7 @@ inline void ggml_cuda_op_rope(
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
     }

+    (void) src1;
     (void) dst;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
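
This hunk is part of the commit-wide migration from stashing operator arguments in src1->data to the tensor's own op_params block; the same slot layout recurs in ggml-metal.m below. A self-contained C++ sketch of the convention, with a plain array standing in for a tensor's op_params (indices follow the RoPE code above):

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
    // producer side: integers go directly into the int32 slots,
    // floats are bit-copied with memcpy so no numeric conversion happens
    int32_t op_params[6] = {0};
    op_params[0] = 128; // n_past
    op_params[1] = 64;  // n_dims
    op_params[2] = 0;   // mode
    op_params[3] = 0;   // n_ctx
    const float freq_base = 10000.0f, freq_scale = 1.0f;
    memcpy(op_params + 4, &freq_base,  sizeof(float));
    memcpy(op_params + 5, &freq_scale, sizeof(float));

    // consumer side, mirroring ggml_cuda_op_rope after this change
    const int n_past = op_params[0];
    float fb;
    memcpy(&fb, op_params + 4, sizeof(float));
    printf("n_past = %d, freq_base = %.1f\n", n_past, fb);
}
```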
@@ -3025,11 +3081,12 @@ inline void ggml_cuda_op_diag_mask_inf(
     const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;

-    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_past = ((int32_t *) dst->op_params)[0];

     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);

+    (void) src1;
     (void) dst;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
@@ -3097,6 +3154,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
     const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
     const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(ne03 == ne13);

     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
@@ -3108,12 +3168,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);

     // strides for iteration over dims 3 and 2
-    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
-    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
+    const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
+    const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
     const int64_t src0_stride = ne00 * ne01 * stride_mod;
     const int64_t src1_stride = ne10 * ne11 * stride_mod;
     const int64_t dst_stride = ne0 * ne1 * stride_mod;

+    const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+    const int64_t i03_max = flatten_rows ? 1 : ne03;
+    const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
+    const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
+    GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
@@ -3130,6 +3197,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);

     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+    GGML_ASSERT(!(split && ne02 < ne12));

     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
@@ -3166,7 +3234,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
         } else {
             row_low = 0;
-            row_high = nrows0;
+            row_high = nrows0*i02_divisor;
         }
         if (row_low == row_high) {
             continue;
@@ -3214,16 +3282,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
         }

-        const int64_t i03_max = flatten_rows ? 1 : ne03;
-        const int64_t i02_max = flatten_rows ? 1 : ne02;
-        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-
         for (int64_t i03 = 0; i03 < i03_max; i03++) {
             const int64_t i13 = i03 % ne13;
             for (int64_t i02 = 0; i02 < i02_max; i02++) {
                 const int64_t i12 = i02 % ne12;

-                const int64_t i0 = i03*ne02 + i02;
+                const int64_t i0 = i03*i02_max + i02;

                 // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
                 const int64_t i0_offset_low = row_low/rows_per_iter;
@@ -3257,8 +3321,8 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 const int64_t i11 = i13*ne12 + i12;

                 // for split tensors the data begins at i0 == i0_offset_low
-                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
-                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+                char  * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+                float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
                 float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
                 float * dst_ddf_i  = dst_ddf[id]  + (i0 - i0_offset_low)*dst_stride;
@@ -3299,11 +3363,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     }
                 }

-                if (!src0_on_device || !src0_is_contiguous) {
+                if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
                     if (src0_is_f32) {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     } else {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     }
                 }
@@ -3457,6 +3521,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];

+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(cudaSetDevice(g_main_device));
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
@@ -3469,7 +3535,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
 }

 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -3483,6 +3549,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];

+    const int64_t ne12 = src1->ne[2];
+
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
@@ -3501,7 +3569,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int row_stride_x = nb01 / sizeof(half);
     const int channel_stride_x = nb02 / sizeof(half);

-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
 }

 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3642,7 +3710,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         size_t size = ggml_nbytes_split(tensor, nrows_split);
         const size_t original_size = size;

-        // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
         if (ne0 % MATRIX_ROW_PADDING != 0) {
             size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
                 * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
@@ -3658,7 +3726,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         }

-        CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice));

         extra->data_device[id] = buf;
@@ -3738,7 +3806,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
             char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
             size_t offset = 0;
             if (tensor->op == GGML_OP_VIEW) {
-                memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
+                memcpy(&offset, tensor->op_params, sizeof(size_t));
             }
             extra = ggml_cuda_alloc_temp_tensor_extra();
             extra->data_device[g_main_device] = src0_ddc + offset;

ggml-metal.m

@@ -42,6 +42,7 @@ struct ggml_metal_context {
     id<MTLComputePipelineState> pipeline_##name

     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);

         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];
} }
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_add]; [encoder setComputePipelineState:ctx->pipeline_add];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
const int64_t n = ggml_nelements(dst); const int64_t n = ggml_nelements(dst);
@ -577,7 +585,7 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];
} }
const int n_past = ((int32_t *)(src1->data))[0]; const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -842,9 +850,10 @@ void ggml_metal_graph_compute(
GGML_ASSERT((src0t == GGML_TYPE_F32)); GGML_ASSERT((src0t == GGML_TYPE_F32));
const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past); const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
const int n_head = ((int32_t *) src1->data)[1]; const int n_head = ((int32_t *) dst->op_params)[1];
const float max_bias = ((float *) src1->data)[2]; float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
if (__builtin_popcount(n_head) != 1) { if (__builtin_popcount(n_head) != 1) {
GGML_ASSERT(false && "only power-of-two n_head implemented"); GGML_ASSERT(false && "only power-of-two n_head implemented");
@ -882,15 +891,14 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];
} }
const int n_dims = ((int32_t *) src1->data)[1]; const int n_past = ((int32_t *) dst->op_params)[0];
const int mode = ((int32_t *) src1->data)[2]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_past = ((int32_t *)(src1->data))[0];
float freq_base; float freq_base;
float freq_scale; float freq_scale;
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
[encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_DUP:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];

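The ROPE case above now reads all of its parameters from `dst->op_params` instead of `src1->data`. A hedged sketch of that layout, with slot indices taken directly from the diff (ints in slots 0..2, floats memcpy'd out of slots 4 and 5); the struct and function are illustrative, not ggml types:

```cpp
#include <cstdint>
#include <cstring>

struct rope_params {
    int   n_past, n_dims, mode;
    float freq_base, freq_scale;
};

static rope_params read_rope_params(const int32_t op_params[6]) {
    rope_params rp;
    rp.n_past = op_params[0];
    rp.n_dims = op_params[1];
    rp.mode   = op_params[2];
    // floats are stored bit-exactly in int32 slots, so memcpy them out
    std::memcpy(&rp.freq_base,  op_params + 4, sizeof(float));
    std::memcpy(&rp.freq_scale, op_params + 5, sizeof(float));
    return rp;
}
```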
ggml-metal.metal

@@ -67,6 +67,17 @@ kernel void kernel_add(
     dst[tpig] = src0[tpig] + src1[tpig];
 }
 
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+}
+
 kernel void kernel_mul(
         device const float * src0,
         device const float * src1,

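For reference, the new `kernel_add_row` computes the same thing as this scalar CPU loop, where `ne00` is the row length and `n` the total element count of `src0`; the modulo index is what broadcasts the single `src1` row:

```cpp
#include <cstdint>

// CPU reference for kernel_add_row (illustrative, not part of ggml):
void add_row_reference(const float * src0, const float * src1,
                       float * dst, int64_t n, int64_t ne00) {
    for (int64_t i = 0; i < n; ++i) {
        dst[i] = src0[i] + src1[i % ne00]; // wrap around to reuse the row
    }
}
```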
ggml.c: 616 changed lines (diff suppressed because it is too large)

ggml.h: 4 changed lines

@ -199,6 +199,7 @@
#define GGML_MAX_CONTEXTS 64 #define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6 #define GGML_MAX_SRC 6
#define GGML_MAX_NAME 48 #define GGML_MAX_NAME 48
#define GGML_MAX_OP_PARAMS 32
#define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_N_THREADS 4
@ -418,6 +419,9 @@ extern "C" {
// compute data // compute data
enum ggml_op op; enum ggml_op op;
// op params - allocated as int32_t for alignment
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];
bool is_param; bool is_param;
struct ggml_tensor * grad; struct ggml_tensor * grad;

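`GGML_MAX_OP_PARAMS` is a byte budget: 32 bytes divided by the 4-byte slot size yields eight `int32_t` slots per tensor, which covers the ROPE case above (slots 0..5). A quick compile-time check of that arithmetic:

```cpp
#include <cstdint>

#define GGML_MAX_OP_PARAMS 32 // mirrors the definition added above

static_assert(GGML_MAX_OP_PARAMS / sizeof(uint32_t) == 8,
              "op_params provides eight 32-bit slots");
```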
llama.cpp: 236 changed lines

@ -74,6 +74,7 @@ enum e_model {
MODEL_13B, MODEL_13B,
MODEL_30B, MODEL_30B,
MODEL_65B, MODEL_65B,
MODEL_70B,
}; };
static const size_t kB = 1024; static const size_t kB = 1024;
@ -105,18 +106,18 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
} }
// //
// memory sizes // memory sizes (calculated for n_batch == 512)
// //
static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx) static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
{ {
static std::map<e_model, size_t> k_sizes = { static std::map<e_model, size_t> k_sizes = {
/* empirical scaling, still a guess */ { MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB },
{ MODEL_3B, ((size_t) n_ctx / 16ull + 128ull) * MB }, { MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
{ MODEL_7B, ((size_t) n_ctx / 16ull + 256ull) * MB }, { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
{ MODEL_13B, ((size_t) n_ctx / 12ull + 256ull) * MB }, { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
{ MODEL_30B, ((size_t) n_ctx / 10ull + 256ull) * MB }, { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
{ MODEL_65B, ((size_t) n_ctx / 8ull + 512ull) * MB }, { MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
}; };
return k_sizes; return k_sizes;
} }
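These entries are linear-in-context formulas. For example, at n_ctx = 2048 the 7B scratch0 budget works out to (2048/16 + 100) MB = 228 MB. A quick check of the table arithmetic (standalone, using the constants from the diff):

```cpp
#include <cstdio>

int main() {
    const size_t MB    = 1024*1024;
    const size_t n_ctx = 2048;

    const size_t scratch0_7b  = (n_ctx/16 + 100)*MB; // MODEL_7B entry
    const size_t scratch0_70b = (n_ctx/7  + 164)*MB; // MODEL_70B entry

    printf("7B:  %zu MB\n", scratch0_7b/MB);  // 228 MB
    printf("70B: %zu MB\n", scratch0_70b/MB); // 456 MB
    return 0;
}
```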
@@ -124,38 +125,26 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
 static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   256ull * MB },
-        { MODEL_7B,   512ull * MB },
-        { MODEL_13B,  512ull * MB },
-        { MODEL_30B,  512ull * MB },
-        { MODEL_65B, 1024ull * MB },
+        { MODEL_3B,  128ull * MB },
+        { MODEL_7B,  160ull * MB },
+        { MODEL_13B, 192ull * MB },
+        { MODEL_30B, 256ull * MB },
+        { MODEL_65B, 384ull * MB }, // guess
+        { MODEL_70B, 304ull * MB },
     };
     return k_sizes;
 }
 
-// 2*n_embd*n_ctx*n_layer*sizeof(float16)
-static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
+// used to store the compute graph tensors + non-scratch data
+static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   682ull * MB },
-        { MODEL_7B,  1026ull * MB },
-        { MODEL_13B, 1608ull * MB },
-        { MODEL_30B, 3124ull * MB },
-        { MODEL_65B, 5120ull * MB },
-    };
-    return k_sizes;
-}
-
-// this is mostly needed for temporary mul_mat buffers to dequantize the data
-// not actually needed if BLAS is disabled
-static const std::map<e_model, size_t> & MEM_REQ_EVAL(int n_ctx)
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  ((size_t) n_ctx / 256ull +  512ull) * MB },
-        { MODEL_7B,  ((size_t) n_ctx / 256ull +  768ull) * MB },
-        { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB },
-        { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB },
-        { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB },
+        { MODEL_3B,   8ull * MB },
+        { MODEL_7B,  10ull * MB },
+        { MODEL_13B, 12ull * MB },
+        { MODEL_30B, 16ull * MB },
+        { MODEL_65B, 24ull * MB }, // guess
+        { MODEL_70B, 24ull * MB },
     };
     return k_sizes;
 }
@@ -170,6 +159,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
         { MODEL_13B,  640ull * kB },
         { MODEL_30B,  768ull * kB },
         { MODEL_65B, 1536ull * kB },
+        { MODEL_70B, 1536ull * kB }, // TODO (likely can be reduced)
     };
     return k_sizes;
 }
@@ -184,6 +174,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
         { MODEL_13B, 160ull },
         { MODEL_30B, 208ull },
         { MODEL_65B, 416ull },
+        { MODEL_70B, 416ull }, // TODO (likely can be reduced)
     };
     return k_sizes;
 }
@@ -195,16 +186,42 @@ struct llama_hparams {
     uint32_t n_embd  = 4096;
     uint32_t n_mult  = 256;
     uint32_t n_head  = 32;
+    uint32_t n_head_kv = 32;
     uint32_t n_layer = 32;
     uint32_t n_rot   = 64;
 
+    // LLaMAv2
+    // TODO: load from model data hparams
+    float f_ffn_mult = 1.0f;
+
     float rope_freq_base  = 10000.0f;
     float rope_freq_scale = 1.0f;
 
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 
     bool operator!=(const llama_hparams & other) const {
-        return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams)));
+        return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
+    }
+
+    uint32_t n_gqa() const {
+        return n_head/n_head_kv;
+    }
+
+    uint32_t n_embd_head() const {
+        return n_embd/n_head;
+    }
+
+    uint32_t n_embd_gqa() const {
+        return n_embd/n_gqa();
+    }
+
+    size_t kv_size() const {
+        size_t result = 2ull;
+        result *= (size_t) n_embd_gqa();
+        result *= (size_t) n_ctx;
+        result *= (size_t) n_layer;
+        result *= sizeof(ggml_fp16_t);
+        return result;
     }
 };
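A worked example of the new helpers, using illustrative LLaMA-2 70B-like values (n_embd = 8192, n_head = 64, n_head_kv = 8, n_layer = 80; assumed here for the arithmetic only, not read from any model file). At n_ctx = 2048 the F16 KV cache comes to 640 MiB, one eighth of what full multi-head attention would need:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd = 8192, n_head = 64, n_head_kv = 8;
    const uint32_t n_layer = 80, n_ctx = 2048;

    const uint32_t n_gqa       = n_head/n_head_kv; // 8
    const uint32_t n_embd_head = n_embd/n_head;    // 128
    const uint32_t n_embd_gqa  = n_embd/n_gqa;     // 1024

    // kv_size(): 2 (K and V) * n_embd_gqa * n_ctx * n_layer * sizeof(fp16)
    // (uint16_t stands in for ggml_fp16_t here)
    const size_t kv = 2ull*n_embd_gqa*n_ctx*n_layer*sizeof(uint16_t);

    printf("n_gqa = %u, n_embd_head = %u, n_embd_gqa = %u, kv = %zu MiB\n",
           n_gqa, n_embd_head, n_embd_gqa, kv/(1024*1024)); // kv = 640 MiB
    return 0;
}
```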
@@ -520,6 +537,10 @@ struct llama_file_loader {
         hparams.n_layer = file.read_u32();
         hparams.n_rot   = file.read_u32();
         hparams.ftype   = (enum llama_ftype) file.read_u32();
+
+        // LLaMAv2
+        // TODO: read from header
+        hparams.n_head_kv = hparams.n_head;
     }
     void read_vocab() {
         vocab.id_to_token.resize(hparams.n_vocab);
@@ -818,7 +839,7 @@ static bool kv_cache_init(
         ggml_type wtype,
         int n_ctx,
         int n_gpu_layers) {
-    const int n_embd  = hparams.n_embd;
+    const int n_embd  = hparams.n_embd_gqa();
     const int n_layer = hparams.n_layer;
 
     const int64_t n_mem = n_layer*n_ctx;
@@ -862,6 +883,7 @@ struct llama_context_params llama_context_default_params() {
         /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
+        /*.n_gqa                       =*/ 1,
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
@@ -981,6 +1003,7 @@ static const char *llama_model_type_name(e_model type) {
         case MODEL_13B: return "13B";
         case MODEL_30B: return "30B";
         case MODEL_65B: return "65B";
+        case MODEL_70B: return "70B";
         default: LLAMA_ASSERT(false);
     }
 }
@@ -991,6 +1014,7 @@ static void llama_model_load_internal(
         llama_vocab & vocab,
         int n_ctx,
         int n_batch,
+        int n_gqa,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1012,6 +1036,7 @@ static void llama_model_load_internal(
     model.hparams = ml->file_loader->hparams;
    model.n_gpu_layers = n_gpu_layers;
    llama_file_version file_version = ml->file_loader->file_version;
+
    auto & hparams = model.hparams;
 
    {
@@ -1031,26 +1056,42 @@ static void llama_model_load_internal(
         hparams.n_ctx = n_ctx;
 
+        // LLaMAv2
+        // TODO: temporary until GGUF
+        LLAMA_ASSERT(hparams.n_head % n_gqa == 0);
+        hparams.n_head_kv = hparams.n_head / n_gqa;
+        if (model.type == e_model::MODEL_65B && n_gqa == 8) {
+            fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
+            model.type = e_model::MODEL_70B;
+            hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
+        }
+
         hparams.rope_freq_base  = rope_freq_base;
         hparams.rope_freq_scale = rope_freq_scale;
     }
 
-    const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+    // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199
+    const uint32_t n_ff_raw  = 2*(4*hparams.n_embd)/3;
+    const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw;
+    const uint32_t n_ff      = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+    //const uint32_t n_ff = 28672;
 
     {
-        LLAMA_LOG_INFO("%s: format     = %s",   __func__, llama_file_version_name(file_version));
-        LLAMA_LOG_INFO("%s: n_vocab    = %u",   __func__, hparams.n_vocab);
-        LLAMA_LOG_INFO("%s: n_ctx      = %u",   __func__, hparams.n_ctx);
-        LLAMA_LOG_INFO("%s: n_embd     = %u",   __func__, hparams.n_embd);
-        LLAMA_LOG_INFO("%s: n_mult     = %u",   __func__, hparams.n_mult);
-        LLAMA_LOG_INFO("%s: n_head     = %u",   __func__, hparams.n_head);
-        LLAMA_LOG_INFO("%s: n_layer    = %u",   __func__, hparams.n_layer);
-        LLAMA_LOG_INFO("%s: n_rot      = %u",   __func__, hparams.n_rot);
-        LLAMA_LOG_INFO("%s: freq_base  = %.1f", __func__, hparams.rope_freq_base);
-        LLAMA_LOG_INFO("%s: freq_scale = %g",   __func__, hparams.rope_freq_scale);
-        LLAMA_LOG_INFO("%s: ftype      = %u (%s)", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
-        LLAMA_LOG_INFO("%s: n_ff       = %u",   __func__, n_ff);
-        LLAMA_LOG_INFO("%s: model size = %s",   __func__, llama_model_type_name(model.type));
+        fprintf(stderr, "%s: format     = %s\n",  __func__, llama_file_version_name(file_version));
+        fprintf(stderr, "%s: n_vocab    = %u\n",  __func__, hparams.n_vocab);
+        fprintf(stderr, "%s: n_ctx      = %u\n",  __func__, hparams.n_ctx);
+        fprintf(stderr, "%s: n_embd     = %u\n",  __func__, hparams.n_embd);
+        fprintf(stderr, "%s: n_mult     = %u\n",  __func__, hparams.n_mult);
+        fprintf(stderr, "%s: n_head     = %u\n",  __func__, hparams.n_head);
+        fprintf(stderr, "%s: n_head_kv  = %u\n",  __func__, hparams.n_head_kv);
+        fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
+        fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+        fprintf(stderr, "%s: n_gqa      = %u\n",  __func__, hparams.n_gqa());
+        fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
+        fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
+        fprintf(stderr, "%s: freq_scale = %g\n",  __func__, hparams.rope_freq_scale);
+        fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
+        fprintf(stderr, "%s: model size = %s\n",  __func__, llama_model_type_name(model.type));
     }
 
     if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
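A worked example of the new feed-forward sizing for the 70B path, assuming n_embd = 8192 with f_ffn_mult = 1.3f, and n_mult = 4096 (the multiple_of value from the 70B params.json; assumed here, since n_mult is actually read from the model file):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd = 8192, n_mult = 4096; // assumed 70B-like values
    const float    f_ffn_mult = 1.3f;

    const uint32_t n_ff_raw  = 2*(4*n_embd)/3;          // 21845
    const uint32_t n_ff_mult = f_ffn_mult*n_ff_raw;     // 28398
    const uint32_t n_ff      = ((n_ff_mult + n_mult - 1)/n_mult)*n_mult; // round up

    printf("n_ff = %u\n", n_ff); // 28672, matching the commented-out constant
    return 0;
}
```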
@@ -1120,6 +1161,7 @@ static void llama_model_load_internal(
     size_t vram_scratch = 0;
     {
         const uint32_t n_embd     = hparams.n_embd;
+        const uint32_t n_embd_gqa = hparams.n_embd_gqa();
         const uint32_t n_layer    = hparams.n_layer;
         const uint32_t n_vocab    = hparams.n_vocab;
@@ -1170,8 +1212,8 @@ static void llama_model_load_internal(
             layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
 
             layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd},     backend_split);
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd},     backend_split);
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd},     backend_split);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split);
             layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd},     backend_split);
 
             layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
@@ -1201,11 +1243,11 @@ static void llama_model_load_internal(
                 mmapped_size - vram_weights + // weights in VRAM not in memory
                 MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
                 MEM_REQ_SCRATCH1().at(model.type) +
-                MEM_REQ_EVAL(hparams.n_ctx).at(model.type);
+                MEM_REQ_EVAL().at(model.type);
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
-            scale*MEM_REQ_KV_SELF().at(model.type);
+            scale*hparams.kv_size();
 
         LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
@@ -1246,7 +1288,7 @@ static void llama_model_load_internal(
                 LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option", __func__);
             } else {
                 LLAMA_LOG_INFO("%s: offloading v cache to GPU", __func__);
-                vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
+                vram_kv_cache += hparams.kv_size() / 2;
             }
         }
         if (n_gpu_layers > (int) hparams.n_layer + 2) {
@@ -1254,7 +1296,7 @@ static void llama_model_load_internal(
                 LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option", __func__);
             } else {
                 LLAMA_LOG_INFO("%s: offloading k cache to GPU", __func__);
-                vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
+                vram_kv_cache += hparams.kv_size() / 2;
             }
         }
 #elif defined(GGML_USE_CLBLAST)
@@ -1302,6 +1344,7 @@ static bool llama_model_load(
         llama_vocab & vocab,
         int n_ctx,
         int n_batch,
+        int n_gqa,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1315,7 +1358,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1359,17 +1402,22 @@ static bool llama_eval_internal(
     LLAMA_ASSERT(!!kv_self.ctx);
 
-    const int n_embd       = hparams.n_embd;
-    const int n_layer      = hparams.n_layer;
-    const int n_ctx        = hparams.n_ctx;
-    const int n_head       = hparams.n_head;
-    const int n_vocab      = hparams.n_vocab;
-    const int n_rot        = hparams.n_embd/hparams.n_head;
-    const int n_gpu_layers = model.n_gpu_layers;
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv;
+    const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_vocab     = hparams.n_vocab;
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+
+    LLAMA_ASSERT(n_embd_head == hparams.n_rot);
 
     const float freq_base  = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
 
+    const int n_gpu_layers = model.n_gpu_layers;
+
     auto & mem_per_token = lctx.mem_per_token;
     auto & buf_compute   = lctx.buf_compute;
@@ -1467,11 +1515,11 @@ static bool llama_eval_internal(
                 offload_func_kq(tmpq);
                 ggml_set_name(tmpq, "tmpq");
 
-                struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
+                struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
                 offload_func_kq(Kcur);
                 ggml_set_name(Kcur, "Kcur");
 
-                struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
+                struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
                 offload_func_kq(Qcur);
                 ggml_set_name(Qcur, "Qcur");
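With grouped-query attention, only the head count differs between the two RoPE calls: Q keeps all n_head query heads while K is reshaped to n_head_kv key heads, and both rotate over the same n_embd_head dimension (which is why n_rot could be replaced by n_embd_head). A minimal consistency check, reusing the illustrative 70B-like numbers from earlier:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_embd = 8192, n_head = 64, n_head_kv = 8; // assumed values
    const int64_t n_embd_head = n_embd/n_head;               // 128
    const int64_t n_embd_gqa  = n_embd/(n_head/n_head_kv);   // 1024

    // ggml_reshape_3d only reinterprets the data, so element counts must match:
    assert(n_embd_head*n_head    == n_embd);     // Qcur: [n_embd_head, n_head,    N]
    assert(n_embd_head*n_head_kv == n_embd_gqa); // Kcur: [n_embd_head, n_head_kv, N]
    return 0;
}
```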
@@ -1483,17 +1531,17 @@ static bool llama_eval_internal(
                     offload_func_v(tmpv);
                     ggml_set_name(tmpv, "tmpv");
 
-                    struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N));
+                    struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
                     offload_func_v(Vcur);
                     ggml_set_name(Vcur, "Vcur");
 
-                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
                     offload_func_kq(k);
                     ggml_set_name(k, "k");
 
-                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
                             (   n_ctx)*ggml_element_size(kv_self.v),
-                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
 
                     offload_func_v(v);
                     ggml_set_name(v, "v");
@@ -1512,8 +1560,8 @@ static bool llama_eval_internal(
                 struct ggml_tensor * K =
                     ggml_permute(ctx0,
                             ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
-                                n_embd/n_head, n_head, n_past + N),
+                                ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa),
+                                n_embd_head, n_head_kv, n_past + N),
                             0, 2, 1, 3);
                 offload_func_kq(K);
                 ggml_set_name(K, "K");
@@ -1523,9 +1571,9 @@ static bool llama_eval_internal(
                 offload_func_kq(KQ);
                 ggml_set_name(KQ, "KQ");
 
-                // KQ_scaled = KQ / sqrt(n_embd/n_head)
+                // KQ_scaled = KQ / sqrt(n_embd_head)
                 struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
-                ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
+                ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
                 // KQ_scaled shape [n_past + N, N, n_head, 1]
                 struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
@@ -1545,10 +1593,10 @@ static bool llama_eval_internal(
                 // split cached V into n_head heads
                 struct ggml_tensor * V =
                     ggml_view_3d(ctx0, kv_self.v,
-                            n_past + N, n_embd/n_head, n_head,
+                            n_past + N, n_embd_head, n_head_kv,
                             n_ctx*ggml_element_size(kv_self.v),
-                            n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
-                            il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+                            n_ctx*ggml_element_size(kv_self.v)*n_embd_head,
+                            n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il);
                 offload_func_v(V);
                 ggml_set_name(V, "V");
@@ -1560,7 +1608,7 @@ static bool llama_eval_internal(
                 // make V contiguous in memory to speed up the matmul, however we waste time on the copy
                 // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
                 // is there a better way?
-                struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+                struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
 
                 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
 #endif
@@ -1754,10 +1802,12 @@ static bool llama_eval_internal(
     }
 
 #if 0
-    printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
+    printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
            ggml_used_mem(ctx0)/1024.0/1024.0,
            lctx.get_buf_max_mem(0)/1024.0/1024.0,
-           lctx.get_buf_max_mem(1)/1024.0/1024.0);
+           lctx.get_buf_max_mem(1)/1024.0/1024.0,
+           lctx.work_buffer.size()/1024.0/1024.0,
+           n_past, N);
 #endif
 
     ggml_free(ctx0);
@@ -2548,16 +2598,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         } else {
             new_type = quantized_type;
 #ifdef GGML_USE_K_QUANTS
-            bool convert_incompatible_tensor = false;
-            if (quantized_type == GGML_TYPE_Q2_K || quantized_type == GGML_TYPE_Q3_K || quantized_type == GGML_TYPE_Q4_K ||
-                quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) {
-                int nx = tensor.ne.at(0);
-                int ny = tensor.ne.at(1);
-                if (nx % QK_K != 0 || ny % QK_K != 0) {
-                    LLAMA_LOG_INFO("\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.",nx,ny,QK_K);
-                    convert_incompatible_tensor = true;
-                }
-            }
             if (tensor.name == "output.weight") {
                 int nx = tensor.ne.at(0);
                 int ny = tensor.ne.at(1);
@@ -2583,6 +2623,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
             }
+            bool convert_incompatible_tensor = false;
+            if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
+                new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {
+                int nx = tensor.ne.at(0);
+                int ny = tensor.ne.at(1);
+                if (nx % QK_K != 0 || ny % QK_K != 0) {
+                    fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K);
+                    convert_incompatible_tensor = true;
+                }
+            }
             if (convert_incompatible_tensor) {
                 if (tensor.name == "output.weight") {
                     new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing.
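Note that the relocated check now tests the final new_type (after the Q4_K/Q5_K promotions above) rather than the originally requested quantized_type. A sketch of the divisibility rule, assuming QK_K = 256, the k-quant super-block size; the helper is illustrative, not llama.cpp code:

```cpp
#include <cstdio>

static bool is_k_quant_compatible(int nx, int ny, int qk_k = 256) {
    return nx % qk_k == 0 && ny % qk_k == 0;
}

int main() {
    printf("%d\n", is_k_quant_compatible(4096, 4096)); // 1 -> quantize as planned
    printf("%d\n", is_k_quant_compatible(4096, 3200)); // 0 -> fall back (e.g. F16)
    return 0;
}
```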
@@ -2609,7 +2659,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 f32_data = (float *) f32_conv_buf.addr;
             }
 
-            printf("quantizing .. ");
+            printf("quantizing to %s .. ", ggml_type_name(new_type));
             fflush(stdout);
 
             work.resize(nelements * 4); // upper bound on size
@@ -2712,7 +2762,7 @@ struct llama_model * llama_load_model_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers,
+    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
                 params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
                 memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
                 params.progress_callback_user_data)) {
@@ -2790,7 +2840,7 @@ struct llama_context * llama_new_context_with_model(
         ctx->embedding.resize(hparams.n_embd);
     }
 
-    ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type));
+    ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
 
     ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
     ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
@@ -2814,7 +2864,7 @@ struct llama_context * llama_new_context_with_model(
         const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
 
-        LLAMA_LOG_INFO("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+        fprintf(stderr, "%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
 
 #define LLAMA_METAL_CHECK_BUF(result) \
     if (!(result)) { \

llama.h

@@ -93,6 +93,7 @@ extern "C" {
         uint32_t seed;         // RNG seed, -1 for random
         int32_t  n_ctx;        // text context
         int32_t  n_batch;      // prompt processing batch size
+        int32_t  n_gqa;        // grouped-query attention (TEMP - will be moved to model hparams)
         int32_t  n_gpu_layers; // number of layers to store in VRAM
         int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
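A hedged usage sketch for the new field: loading a LLaMA-2 70B model through the temporary n_gqa parameter, assuming the llama.h API at this revision ("ggml-model-70b.bin" is a placeholder path; backend setup and error reporting are omitted for brevity):

```cpp
#include "llama.h"

int main() {
    llama_context_params params = llama_context_default_params();
    params.n_gqa = 8; // grouped-query attention factor of the 70B model

    llama_model * model = llama_load_model_from_file("ggml-model-70b.bin", params);
    if (model == nullptr) {
        return 1; // load failed
    }

    llama_free_model(model);
    return 0;
}
```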