Merge remote-tracking branch 'origin/master' into prompt-array
commit 88535ed036
24 changed files with 1612 additions and 473 deletions
README.md (12 changed lines)

@@ -39,6 +39,7 @@ Last revision compatible with the old format: [dadbed9](https://github.com/ggerg
     <li><a href="#memorydisk-requirements">Memory/Disk Requirements</a></li>
     <li><a href="#quantization">Quantization</a></li>
     <li><a href="#interactive-mode">Interactive mode</a></li>
+    <li><a href="#constrained-output-with-grammars">Constrained output with grammars</a></li>
     <li><a href="#instruction-mode-with-alpaca">Instruction mode with Alpaca</a></li>
     <li><a href="#using-openllama">Using OpenLLaMA</a></li>
     <li><a href="#using-gpt4all">Using GPT4All</a></li>
@@ -604,6 +605,16 @@ PROMPT_TEMPLATE=./prompts/chat-with-bob.txt PROMPT_CACHE_FILE=bob.prompt.bin \
     CHAT_SAVE_DIR=./chat/bob ./examples/chat-persistent.sh
 ```

+### Constrained output with grammars
+
+`llama.cpp` supports grammars to constrain model output. For example, you can force the model to output JSON only:
+
+```bash
+./main -m ./models/13B/ggml-model-q4_0.gguf -n 256 --grammar-file grammars/json.gbnf -p 'Request: schedule a call at 8pm; Command:'
+```
+
+The `grammars/` folder contains a handful of sample grammars. To write your own, check out the [GBNF Guide](./grammars/README.md).
+
 ### Instruction mode with Alpaca

 1. First, download the `ggml` Alpaca model into the `./models` folder
@@ -885,3 +896,4 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /
 - [BLIS](./docs/BLIS.md)
 - [Performance troubleshooting](./docs/token_generation_performance_tips.md)
 - [GGML tips & tricks](https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks)
+- [GBNF grammars](./grammars/README.md)

@@ -289,7 +289,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -388,11 +387,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 #else
             fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
 #endif // GGML_USE_CUBLAS
-        } else if (arg == "--mul-mat-q" || arg == "-mmq") {
+        } else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
 #ifdef GGML_USE_CUBLAS
-            params.mul_mat_q = true;
+            params.mul_mat_q = false;
 #else
-            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n");
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
 #endif // GGML_USE_CUBLAS
         } else if (arg == "--low-vram" || arg == "-lv") {
 #ifdef GGML_USE_CUBLAS
@@ -602,9 +601,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     fprintf(stdout, "  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
     fprintf(stdout, "  -lv, --low-vram       don't allocate VRAM scratch buffer\n");
-    fprintf(stdout, "  -mmq, --mul-mat-q     use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
-    fprintf(stdout, "                        Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
-    fprintf(stdout, "                        is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
+    fprintf(stdout, "  -nommq, --no-mul-mat-q\n");
+    fprintf(stdout, "                        use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
+    fprintf(stdout, "                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif
     fprintf(stdout, "  --mtest               compute maximum memory usage\n");
     fprintf(stdout, "  --export              export the computation graph to 'llama.ggml'\n");

@@ -68,7 +68,7 @@ struct gpt_params {
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score

     bool low_vram          = false; // if true, reduce VRAM usage at the cost of performance
-    bool mul_mat_q         = false; // if true, use experimental mul_mat_q kernels
+    bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs

@@ -1,10 +1,12 @@
-import sys, struct, math, argparse
+import sys, struct, math, argparse, warnings
 from pathlib import Path

 import numpy as np

 import gguf

+warnings.filterwarnings('error')
+
 # Note: Does not support GGML_QKK_64
 QK_K = 256
 # Items here are (block size, type size)
@@ -215,15 +217,10 @@ class GGMLToGGUF:
         if self.vocab_override is not None:
             vo = self.vocab_override
             print('* Adding vocab item(s)')
-            for (idx, vitem) in enumerate(vo.all_tokens()):
-                if len(vitem) == 3:
-                    tokens.append(vitem[0])
-                    scores.append(vitem[1])
-                    toktypes.append(vitem[2])
-                else:
-                    # Maybe try to guess the token type here?
-                    tokens.append(vitem[0])
-                    scores.append(vitem[1])
+            for (idx, (vbytes, score, ttype)) in enumerate(vo.all_tokens()):
+                tokens.append(vbytes)
+                scores.append(score)
+                toktypes.append(ttype)
             assert len(tokens) == hp.n_vocab, f'Override vocab has a different number of items than hyperparameters - override = {len(tokens)} but n_vocab={hp.n_vocab}'
             gguf_writer.add_token_list(tokens)
             gguf_writer.add_token_scores(scores)
@@ -231,13 +228,24 @@ class GGMLToGGUF:
             gguf_writer.add_token_types(toktypes)
             return
         print(f'* Adding {hp.n_vocab} vocab item(s)')
+        assert len(self.model.vocab.items) >= 3, 'Cannot handle unexpectedly short model vocab'
         for (tokid, (vbytes, vscore)) in enumerate(self.model.vocab.items):
             tt = 1 # Normal
-            if len(vbytes) == 0:
+            # Special handling for UNK, BOS, EOS tokens.
+            if tokid <= 2:
+                if tokid == 0:
+                    vbytes = b'<unk>'
+                    tt = 2
+                elif tokid == 1:
+                    vbytes = b'<s>'
+                    tt = 3
+                else:
+                    vbytes = b'</s>'
+                    tt = 3
+            elif len(vbytes) == 0:
                 tt = 3 # Control
             elif tokid >= 3 and tokid <= 258 and len(vbytes) == 1:
-                hv = hex(vbytes[0])[2:].upper()
-                vbytes = bytes(f'<0x{hv}>', encoding = 'UTF-8')
+                vbytes = bytes(f'<0x{vbytes[0]:02X}>', encoding = 'UTF-8')
                 tt = 6 # Byte
             else:
                 vbytes = vbytes.replace(b' ', b'\xe2\x96\x81')
@@ -247,6 +255,9 @@ class GGMLToGGUF:
         gguf_writer.add_token_list(tokens)
         gguf_writer.add_token_scores(scores)
         gguf_writer.add_token_types(toktypes)
+        gguf_writer.add_unk_token_id(0)
+        gguf_writer.add_bos_token_id(1)
+        gguf_writer.add_eos_token_id(2)

     def add_tensors(self, gguf_writer):
         nm = self.name_map
@@ -316,7 +327,11 @@ def main():
     data = np.memmap(cfg.input, mode = 'r')
     model = GGMLV3Model()
     print('* Scanning GGML input file')
-    offset = model.load(data, 0)
+    try:
+        offset = model.load(data, 0)
+    except OverflowError:
+        print(f'!!! Caught overflow loading tensors. The most likely issue is running on Windows but not in WSL. Try running in WSL if possible.', file = sys.stderr)
+        raise
     print(f'* GGML model hyperparameters: {model.hyperparameters}')
     vocab_override = None
     params_override = None
@@ -331,4 +346,5 @@ def main():
     converter.save()
     print(f'* Successful completion. Output saved to: {cfg.output}')

-main()
+if __name__ == '__main__':
+    main()

convert.py (27 changed lines)

@@ -69,7 +69,10 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'I32': DT_I32,
 }

-class GGMLFileType(enum.Enum):
+# TODO: match this with `llama_ftype`
+# TODO: rename to LLAMAFileType
+# TODO: move to `gguf.py`
+class GGMLFileType(enum.IntEnum):
     AllF32     = 0
     MostlyF16  = 1  # except 1d tensors

@@ -101,6 +104,8 @@ class Params:
     n_head_kv:  int
     f_norm_eps: float

+    ftype: Optional[GGMLFileType] = None
+
     @staticmethod
     def find_n_mult(n_ff: int, n_embd: int) -> int:
         # hardcoded magic range
@@ -738,6 +743,9 @@ class OutputFile:
         self.gguf.add_head_count_kv      (params.n_head_kv)
         self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)

+        if params.ftype:
+            self.gguf.add_file_type(params.ftype)
+
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
         scores = []
@@ -956,7 +964,7 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, Sentence
             path = path3
         else:
             raise FileNotFoundError(
-                f"Could not find tokenizer.model in {path} or its parent; "
+                f"Could not find {vocab_file} in {path} or its parent; "
                 "if it's in another directory, pass the directory as --vocab-dir")

     print(f"Loading vocab file '{path}', type '{vocabtype}'")
@@ -1020,6 +1028,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
                 "   - LLaMA v2: --ctx 4096\n")
         params.n_ctx = args.ctx

+    if args.outtype:
+        params.ftype = {
+            "f32": GGMLFileType.AllF32,
+            "f16": GGMLFileType.MostlyF16,
+        }[args.outtype]
+
     print(f"params = {params}")

     vocab: Vocab
@@ -1042,9 +1056,12 @@ def main(args_in: Optional[List[str]] = None) -> None:

         model   = model_plus.model
         model   = convert_model_names(model, params)
-        output_type = pick_output_type(model, args.outtype)
-        model   = convert_to_output_type(model, output_type)
-        outfile = args.outfile or default_outfile(model_plus.paths, output_type)
+        ftype   = pick_output_type(model, args.outtype)
+        model   = convert_to_output_type(model, ftype)
+        outfile = args.outfile or default_outfile(model_plus.paths, ftype)
+
+        params.ftype = ftype
+        print(f"Writing {outfile}, format {ftype}")

         OutputFile.write_all(outfile, params, model, vocab)
         print(f"Wrote {outfile}")

@@ -72,12 +72,20 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "\n");
     }

-    if (params.embedding){
-        if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
+    if (embd_inp.size() > (size_t)params.n_ctx) {
+        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
+                __func__, embd_inp.size(), params.n_ctx);
+        return 1;
+    }
+
+    while (!embd_inp.empty()) {
+        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
+        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
+        n_past += n_tokens;
+        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
     }

     const int n_embd = llama_n_embd(ctx);
@@ -87,7 +95,6 @@ int main(int argc, char ** argv) {
             printf("%f ", embeddings[i]);
         }
         printf("\n");
-    }

     llama_print_timings(ctx);
     llama_free(ctx);

@@ -148,7 +148,7 @@ struct cmd_params {
 };

 static const cmd_params cmd_params_defaults = {
-    /* model         */ {"models/7B/ggml-model-q4_0.bin"},
+    /* model         */ {"models/7B/ggml-model-q4_0.gguf"},
     /* n_prompt      */ {512},
     /* n_gen         */ {128},
     /* n_batch       */ {512},
@@ -179,12 +179,12 @@ static void print_usage(int /* argc */, char ** argv) {
     fprintf(stdout, "  -mg i, --main-gpu <n>          (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     fprintf(stdout, "  -lv, --low-vram <0|1>          (default: %s)\n", join(cmd_params_defaults.low_vram, ",").c_str());
     fprintf(stdout, "  -mmq, --mul-mat-q <0|1>        (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
-    fprintf(stdout, "  -ts, --tensor_split <ts>       \n");
+    fprintf(stdout, "  -ts, --tensor_split <ts0/ts1/..>\n");
     fprintf(stdout, "  -r, --repetitions <n>          (default: %d)\n", cmd_params_defaults.reps);
-    fprintf(stdout, "  -o, --output <csv|json|md|sql> (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : "md");
+    fprintf(stdout, "  -o, --output <csv|json|md|sql> (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : cmd_params_defaults.output_format == MARKDOWN ? "md" : "sql");
     fprintf(stdout, "  -v, --verbose                  (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
     fprintf(stdout, "\n");
-    fprintf(stdout, "Multiple values can be given for each parameter by separating them with ',' or by repeating the parameter.\n");
+    fprintf(stdout, "Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");

 }

@@ -728,7 +728,7 @@ struct markdown_printer : public printer {
         if (!is_cpu_backend) {
             fields.push_back("n_gpu_layers");
         }
-        if (params.n_batch.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) {
+        if (params.n_threads.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) {
             fields.push_back("n_threads");
         }
         if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {

@@ -288,6 +288,10 @@ These options help improve the performance and memory usage of the LLaMA models.

 - `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.

+### Grammars
+
+- `--grammar GRAMMAR`, `--grammar-file FILE`: Specify a grammar (defined inline or in a file) to constrain model output to a specific format. For example, you could force the model to output JSON or to speak only in emojis. See the [GBNF guide](../../grammars/README.md) for details on the syntax.
+
 ### Quantization

 For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-data--run).

@@ -719,12 +719,11 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "                        number of layers to store in VRAM\n");
     fprintf(stdout, "  -ts SPLIT --tensor-split SPLIT\n");
     fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
-    fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     fprintf(stdout, "  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
     fprintf(stdout, "  -lv, --low-vram       don't allocate VRAM scratch buffer\n");
-    fprintf(stdout, "  -mmq, --mul-mat-q     use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
-    fprintf(stdout, "                        Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
-    fprintf(stdout, "                        is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
+    fprintf(stdout, "  -nommq, --no-mul-mat-q\n");
+    fprintf(stdout, "                        use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
+    fprintf(stdout, "                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
@@ -915,12 +914,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {});
 #endif // GGML_USE_CUBLAS
     }
-    else if (arg == "--mul-mat-q" || arg == "-mmq")
+    else if (arg == "--no-mul-mat-q" || arg == "-nommq")
     {
 #ifdef GGML_USE_CUBLAS
-        params.mul_mat_q = true;
+        params.mul_mat_q = false;
 #else
-        LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n", {});
+        LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n", {});
 #endif // GGML_USE_CUBLAS
     }
     else if (arg == "--main-gpu" || arg == "-mg")
@@ -1104,29 +1103,38 @@ static json format_tokenizer_response(const std::vector<llama_token> &tokens)
             {"tokens", tokens}};
 }

+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value)
+{
+    // Fallback null to default value
+    return body.contains(key) && !body.at(key).is_null()
+        ? body.value(key, default_value)
+        : default_value;
+}
+
 static void parse_options_completion(const json &body, llama_server_context &llama)
 {
     gpt_params default_params;

-    llama.stream = body.value("stream", false);
-    llama.params.n_predict = body.value("n_predict", default_params.n_predict);
-    llama.params.top_k = body.value("top_k", default_params.top_k);
-    llama.params.top_p = body.value("top_p", default_params.top_p);
-    llama.params.tfs_z = body.value("tfs_z", default_params.tfs_z);
-    llama.params.typical_p = body.value("typical_p", default_params.typical_p);
-    llama.params.repeat_last_n = body.value("repeat_last_n", default_params.repeat_last_n);
-    llama.params.temp = body.value("temperature", default_params.temp);
-    llama.params.repeat_penalty = body.value("repeat_penalty", default_params.repeat_penalty);
-    llama.params.presence_penalty = body.value("presence_penalty", default_params.presence_penalty);
-    llama.params.frequency_penalty = body.value("frequency_penalty", default_params.frequency_penalty);
-    llama.params.mirostat = body.value("mirostat", default_params.mirostat);
-    llama.params.mirostat_tau = body.value("mirostat_tau", default_params.mirostat_tau);
-    llama.params.mirostat_eta = body.value("mirostat_eta", default_params.mirostat_eta);
-    llama.params.penalize_nl = body.value("penalize_nl", default_params.penalize_nl);
-    llama.params.n_keep = body.value("n_keep", default_params.n_keep);
-    llama.params.seed = body.value("seed", default_params.seed);
-    llama.params.grammar = body.value("grammar", default_params.grammar);
-    llama.params.n_probs = body.value("n_probs", default_params.n_probs);
+    llama.stream = json_value(body, "stream", false);
+    llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
+    llama.params.top_k = json_value(body, "top_k", default_params.top_k);
+    llama.params.top_p = json_value(body, "top_p", default_params.top_p);
+    llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
+    llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
+    llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
+    llama.params.temp = json_value(body, "temperature", default_params.temp);
+    llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
+    llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
+    llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
+    llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
+    llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
+    llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
+    llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
+    llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
+    llama.params.seed = json_value(body, "seed", default_params.seed);
+    llama.params.grammar = json_value(body, "grammar", default_params.grammar);
+    llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);

     if (body.count("prompt") != 0)
     {

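Note on the `json_value` helper introduced above: nlohmann's `json::value()` only falls back to the default when the key is absent, so a client sending e.g. `"top_k": null` would otherwise trigger a type error; the helper also treats an explicit null as "use the default". Below is a minimal standalone sketch of that difference — it is not part of this commit, and it assumes the single-header nlohmann/json library used by the server example.

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // A request body where the field is present but null.
    json body = json::parse(R"({"top_k": null})");

    // json::value() falls back only when the key is MISSING; a present-but-null
    // value is converted to int, which throws json::type_error.
    try {
        int top_k = body.value("top_k", 40);
        std::cout << "value(): " << top_k << "\n";
    } catch (const json::type_error & e) {
        std::cout << "value() threw: " << e.what() << "\n";
    }

    // The helper's logic: fall back to the default when the key is missing OR null.
    int top_k = body.contains("top_k") && !body.at("top_k").is_null()
                    ? body.value("top_k", 40)
                    : 40;
    std::cout << "json_value-style: " << top_k << "\n"; // prints 40
    return 0;
}
```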
@@ -1138,7 +1146,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     }

     llama.params.logit_bias.clear();
-    if (body.value("ignore_eos", false))
+    if (json_value(body, "ignore_eos", false))
     {
         llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
     }
@@ -1445,7 +1453,7 @@ int main(int argc, char **argv)
             {
                 if (res.status == 400) {
                     res.set_content("Invalid request", "text/plain");
-                } else {
+                } else if (res.status != 500) {
                     res.set_content("File Not Found", "text/plain");
                     res.status = 404;
                 } });

@@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
     t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
     t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
-    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
     t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
     t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
-    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
     t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
     t04->grad = expand(gb, ggml_add_inplace(ctx0,
                  ggml_add_inplace(ctx0,

@@ -76,7 +76,7 @@ struct ggml_allocr {
 };

 #ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == NULL) {
             alloc->allocated_tensors[i] = tensor;
@@ -85,7 +85,7 @@ static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tens
     }
     GGML_ASSERT(!"out of allocated_tensors");
 }
-static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == tensor ||
             (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {

ggml-cuda.cu (187 changed lines)

@@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
@@ -286,7 +287,7 @@ static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = false;
+static bool g_mul_mat_q = true;

 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
@@ -3886,13 +3887,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
         return;
     }

-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;

     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3940,9 +3941,32 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
         dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
     }
 }

-static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
+                                 const int n_heads_log2_floor, const float m0, const float m1) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const int k = row/k_rows;
+
+    float m_k;
+    if (k < n_heads_log2_floor) {
+        m_k = powf(m0, k + 1);
+    } else {
+        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+    }
+
+    dst[i] = col * m_k + x[i];
+}
+
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;

     if (col >= ncols) {
         return;
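For reference, the bias the new `alibi_f32` kernel adds is the ALiBi per-head slope formulation. A sketch of the math implemented above, with symbols taken directly from the kernel arguments (max_bias and n_head come from the op parameters, k indexes the head, col is the key position):

```latex
n_{\text{floor}} = 2^{\lfloor \log_2 n_{\text{head}} \rfloor},\quad
m_0 = 2^{-\,\text{max\_bias}/n_{\text{floor}}},\quad
m_1 = 2^{-\,(\text{max\_bias}/2)/n_{\text{floor}}},\quad
m_k = \begin{cases} m_0^{\,k+1} & k < n_{\text{floor}} \\ m_1^{\,2(k-n_{\text{floor}})+1} & k \ge n_{\text{floor}} \end{cases},\quad
\text{dst}[row,col] = x[row,col] + m_k \cdot col
```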
@@ -3955,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int

 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-// values are also not normalized to the maximum value by subtracting it in the exponential function
-// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;

-    float tmp = 0.0;
+    float max_val = -INFINITY;

-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
-
-        if (col >= ncols) {
-            break;
-        }
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        max_val = max(max_val, x[i]);
     }

+    // find the max value in the block
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    }
+
+    float tmp = 0.f;
+
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        const float val = expf(x[i]);
+        const float val = expf(x[i] - max_val);
         tmp += val;
         dst[i] = val;
     }
@@ -3983,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }

-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
+    const float inv_tmp = 1.f / tmp;

-        if (col >= ncols) {
-            break;
-        }
-
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        dst[i] /= tmp;
+        dst[i] *= inv_tmp;
     }
 }

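The rewritten `soft_max_f32` now computes the numerically stable form of softmax: it reduces the row maximum across the warp first and exponentiates shifted values. This relies on the standard identity (a sketch, not part of the diff):

```latex
\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j
```

Shifting by m keeps every exponent at or below zero, so `expf` cannot overflow, which is why the removed comments about skipping this normalization no longer apply.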
@@ -4752,9 +4777,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }

@@ -4766,16 +4791,25 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
 }

+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }

@@ -5501,6 +5535,41 @@ inline void ggml_cuda_op_rope(
     (void) i1;
 }

+inline void ggml_cuda_op_alibi(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i  != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+    GGML_ASSERT(ne01 + n_past == ne00);
+    GGML_ASSERT(n_head == ne02);
+
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // compute
+    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+
+    (void) src1;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -6121,6 +6190,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }

+void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+}
+
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -6240,7 +6314,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6249,14 +6323,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }

     tensor->backend = GGML_BACKEND_GPU;

+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;

     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6308,16 +6387,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }

+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }

 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }

 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }

 void ggml_cuda_set_main_device(int main_device) {
@@ -6456,6 +6567,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_rope;
             break;
+        case GGML_OP_ALIBI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_alibi;
+            break;
         default:
             return false;
     }

@@ -16,9 +16,14 @@ GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void   ggml_cuda_set_main_device(int main_device);
 GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void   ggml_cuda_set_scratch_size(size_t scratch_size);

@@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
         //load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,14 +1896,14 @@ kernel void kernel_mul_mm(device const uchar * src0,
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
-            threadgroup_barrier(mem_flags::mem_device);
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }

-        threadgroup_barrier(mem_flags::mem_device);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
         if (sgitg==0) {
             for (int i = 0; i < n_rows; i++) {

120 ggml.h
@@ -211,6 +211,7 @@
 #define GGML_MAX_OP_PARAMS 32
 #define GGML_DEFAULT_N_THREADS 4

+
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1

@@ -259,8 +260,9 @@
 extern "C" {
 #endif

-#ifdef __ARM_NEON
-    // we use the built-in 16-bit float type
+#if defined(__ARM_NEON) && defined(__CUDACC__)
+    typedef half ggml_fp16_t;
+#elif defined(__ARM_NEON)
     typedef __fp16 ggml_fp16_t;
 #else
     typedef uint16_t ggml_fp16_t;
@@ -344,10 +346,12 @@ extern "C" {
        GGML_OP_ARGMAX,
        GGML_OP_REPEAT,
        GGML_OP_REPEAT_BACK,
+       GGML_OP_CONCAT,
        GGML_OP_SILU_BACK,
        GGML_OP_NORM, // normalize
        GGML_OP_RMS_NORM,
        GGML_OP_RMS_NORM_BACK,
+       GGML_OP_GROUP_NORM,

        GGML_OP_MUL_MAT,
        GGML_OP_OUT_PROD,
@@ -373,14 +377,19 @@ extern "C" {
        GGML_OP_CLAMP,
        GGML_OP_CONV_1D,
        GGML_OP_CONV_2D,
+       GGML_OP_CONV_TRANSPOSE_2D,
        GGML_OP_POOL_1D,
        GGML_OP_POOL_2D,

+       GGML_OP_UPSCALE, // nearest interpolate
+
        GGML_OP_FLASH_ATTN,
        GGML_OP_FLASH_FF,
        GGML_OP_FLASH_ATTN_BACK,
        GGML_OP_WIN_PART,
        GGML_OP_WIN_UNPART,
+       GGML_OP_GET_REL_POS,
+       GGML_OP_ADD_REL_POS,

        GGML_OP_UNARY,

@@ -804,6 +813,13 @@ extern "C" {
            struct ggml_tensor * a,
            struct ggml_tensor * b);

+   // concat a and b on dim 2
+   // used in stable-diffusion
+   GGML_API struct ggml_tensor * ggml_concat(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           struct ggml_tensor * b);
+
    GGML_API struct ggml_tensor * ggml_abs(
            struct ggml_context * ctx,
            struct ggml_tensor * a);
@@ -912,6 +928,19 @@ extern "C" {
            struct ggml_tensor * a,
            float eps);

+   // group normalize along ne0*ne1*n_groups
+   // used in stable-diffusion
+   // TODO: eps is hardcoded to 1e-6 for now
+   GGML_API struct ggml_tensor * ggml_group_norm(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           int n_groups);
+
+   GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           int n_groups);
+
    // a - x
    // b - dy
    // TODO: update with configurable eps
@@ -1212,6 +1241,15 @@ extern "C" {
            float freq_base,
            float freq_scale);

+   // xPos RoPE, in-place, returns view(a)
+   GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           int n_past,
+           int n_dims,
+           float base,
+           bool down);
+
    // rotary position embedding backward, i.e compute dx from dy
    // a - dy
    GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1220,7 +1258,11 @@ extern "C" {
            int n_past,
            int n_dims,
            int mode,
-           int n_ctx);
+           int n_ctx,
+           float freq_base,
+           float freq_scale,
+           float xpos_base,
+           bool xpos_down);

    // alibi position embedding
    // in-place, returns view(a)
@@ -1247,6 +1289,15 @@ extern "C" {
            int p0, // padding
            int d0); // dilation

+   // conv_1d with padding = half
+   // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+   GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           struct ggml_tensor * b,
+           int s,
+           int d);
+
    GGML_API struct ggml_tensor * ggml_conv_2d(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
@@ -1258,14 +1309,38 @@ extern "C" {
            int d0,
            int d1);

-   // conv_1d with padding = half
-   // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
-   GGML_API struct ggml_tensor * ggml_conv_1d_ph(
+   // kernel size is a->ne[0] x a->ne[1]
+   // stride is equal to kernel size
+   // padding is zero
+   // example:
+   // a: 16 16 3 768
+   // b: 1024 1024 3 1
+   // res: 64 64 768 1
+   // used in sam
+   GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           struct ggml_tensor * b);
+
+   // kernel size is a->ne[0] x a->ne[1]
+   // stride is 1
+   // padding is half
+   // example:
+   // a: 3 3 256 256
+   // b: 64 64 256 1
+   // res: 64 64 256 1
+   // used in sam
+   GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           struct ggml_tensor * b);
+
+   GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            struct ggml_tensor * b,
-           int s,
-           int d);
+           int stride);

    enum ggml_op_pool {
        GGML_OP_POOL_MAX,
@@ -1292,6 +1367,13 @@ extern "C" {
            int p0,
            int p1);

+   // nearest interpolate
+   // used in stable-diffusion
+   GGML_API struct ggml_tensor * ggml_upscale(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           int scale_factor);
+
    GGML_API struct ggml_tensor * ggml_flash_attn(
            struct ggml_context * ctx,
            struct ggml_tensor * q,
@@ -1345,6 +1427,27 @@ extern "C" {
            struct ggml_tensor * a,
            enum ggml_unary_op op);

+   // used in sam
+   GGML_API struct ggml_tensor * ggml_get_rel_pos(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           int qh,
+           int kh);
+
+   // used in sam
+
+   GGML_API struct ggml_tensor * ggml_add_rel_pos(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           struct ggml_tensor * pw,
+           struct ggml_tensor * ph);
+
+   GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor * a,
+           struct ggml_tensor * pw,
+           struct ggml_tensor * ph);
+
    // custom operators

    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1499,6 +1602,7 @@ extern "C" {
            struct ggml_context * ctx,
            struct ggml_tensor * tensor);

+
    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

    GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
4 gguf.py
@@ -26,6 +26,7 @@ KEY_GENERAL_DESCRIPTION = "general.description"
 KEY_GENERAL_LICENSE = "general.license"
 KEY_GENERAL_SOURCE_URL = "general.source.url"
 KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
+KEY_GENERAL_FILE_TYPE = "general.file_type"

 # LLM
 KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length"
@@ -595,6 +596,9 @@ class GGUFWriter:
     def add_source_hf_repo(self, repo: str):
         self.add_string(KEY_GENERAL_SOURCE_HF_REPO, repo)

+    def add_file_type(self, ftype: int):
+        self.add_uint32(KEY_GENERAL_FILE_TYPE, ftype)
+
     def add_name(self, name: str):
         self.add_string(KEY_GENERAL_NAME, name)

91 grammars/README.md (new file)
@@ -0,0 +1,91 @@
# GBNF Guide

GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis. GBNF grammars are supported in various ways in `examples/main` and `examples/server`.

## Background

[Backus-Naur Form (BNF)](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form) is a notation for describing the syntax of formal languages like programming languages, file formats, and protocols. GBNF is an extension of BNF that primarily adds a few modern regex-like features.

## Basics

In GBNF, we define *production rules* that specify how a *non-terminal* (rule name) can be replaced with sequences of *terminals* (characters, specifically Unicode [code points](https://en.wikipedia.org/wiki/Code_point)) and other non-terminals. The basic format of a production rule is `nonterminal ::= sequence...`.
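For instance, a pair of production rules in this format could look like the following sketch (an illustration in the same notation, not one of the grammars shipped under `grammars/`):

```
# `answer` is a non-terminal; "yes" and "no" are sequences of terminals
root   ::= answer
answer ::= "yes" | "no"
```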
## Example

Before going deeper, let's look at some of the features demonstrated in `grammars/chess.gbnf`, a small chess notation grammar:
```
# `root` specifies the pattern for the overall output
root ::= (
    # it must start with the characters "1. " followed by a sequence
    # of characters that match the `move` rule, followed by a space, followed
    # by another move, and then a newline
    "1. " move " " move "\n"

    # it's followed by one or more subsequent moves, numbered with one or two digits
    ([1-9] [0-9]? ". " move " " move "\n")+
)

# `move` is an abstract representation, which can be a pawn, nonpawn, or castle.
# The `[+#]?` denotes the possibility of checking or mate signs after moves
move ::= (pawn | nonpawn | castle) [+#]?

pawn ::= ...
nonpawn ::= ...
castle ::= ...
```

## Non-Terminals and Terminals

Non-terminal symbols (rule names) stand for a pattern of terminals and other non-terminals. They are required to be a dashed lowercase word, like `move`, `castle`, or `check-mate`.

Terminals are actual characters ([code points](https://en.wikipedia.org/wiki/Code_point)). They can be specified as a sequence like `"1"` or `"O-O"` or as ranges like `[1-9]` or `[NBKQR]`.

## Characters and character ranges

Terminals support the full range of Unicode. Unicode characters can be specified directly in the grammar, for example `hiragana ::= [ぁ-ゟ]`, or with escapes: 8-bit (`\xXX`), 16-bit (`\uXXXX`) or 32-bit (`\UXXXXXXXX`).

Character ranges can be negated with `^`:
```
single-line ::= [^\n]+ "\n"
```
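Combining literal characters, ranges, and a negated range, a rule for a double-quoted string could be sketched as follows (an illustration in the same notation, not one of the grammars in this directory):

```
# an opening quote, any number of characters that are not a quote or a newline,
# then a closing quote
quoted ::= "\"" [^"\n]* "\""
```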
## Sequences and Alternatives

The order of symbols in a sequence matters. For example, in `"1. " move " " move "\n"`, the `"1. "` must come before the first `move`, etc.

Alternatives, denoted by `|`, give different sequences that are acceptable. For example, in `move ::= pawn | nonpawn | castle`, `move` can be a `pawn` move, a `nonpawn` move, or a `castle`.

Parentheses `()` can be used to group sequences, which allows for embedding alternatives in a larger rule or applying repetition and optional symbols (below) to a sequence.

## Repetition and Optional Symbols

- `*` after a symbol or sequence means that it can be repeated zero or more times.
- `+` denotes that the symbol or sequence should appear one or more times.
- `?` makes the preceding symbol or sequence optional.
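Putting these operators together with sequences, alternatives, and grouping, a comma-separated list of optionally signed integers could be written like this (a sketch in the same notation, not part of the committed grammars):

```
root    ::= integer ("," integer)*
integer ::= "-"? ("0" | [1-9] [0-9]*)
```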
## Comments and newlines

Comments can be specified with `#`:
```
# defines optional whitespace
ws ::= [ \t\n]+
```

Newlines are allowed between rules and between symbols or sequences nested inside parentheses. Additionally, a newline after an alternate marker `|` will continue the current rule, even outside of parentheses.

## The root rule

In a full grammar, the `root` rule always defines the starting point of the grammar. In other words, it specifies what the entire output must match.

```
# a grammar for lists
root ::= ("- " item)+
item ::= [^\n]+ "\n"
```

## Next steps

This guide provides a brief overview. Check out the GBNF files in this directory (`grammars/`) for examples of full grammars. You can try them out with:
```
./main -m <model> --grammar-file grammars/some-grammar.gbnf -p 'Some prompt'
```
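Tying together the pieces above (a `root` rule, repetition, ranges, and comments), a minimal grammar for `key = value` lines could be sketched as follows; this is a further illustration in the same notation, not one of the files added under `grammars/` by this commit:

```
# one "key = value" pair per line
root  ::= pair+
pair  ::= key " = " value "\n"
key   ::= [a-z]+
value ::= [^\n]+
```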
164 k_quants.c
@ -77,6 +77,11 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
|
||||||
}
|
}
|
||||||
return 1/iscale;
|
return 1/iscale;
|
||||||
}
|
}
|
||||||
|
bool return_early = false;
|
||||||
|
if (rmse_type < 0) {
|
||||||
|
rmse_type = -rmse_type;
|
||||||
|
return_early = true;
|
||||||
|
}
|
||||||
int weight_type = rmse_type%2;
|
int weight_type = rmse_type%2;
|
||||||
float sumlx = 0;
|
float sumlx = 0;
|
||||||
float suml2 = 0;
|
float suml2 = 0;
|
||||||
|
@ -89,56 +94,9 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
|
||||||
suml2 += w*l*l;
|
suml2 += w*l*l;
|
||||||
}
|
}
|
||||||
float scale = sumlx/suml2;
|
float scale = sumlx/suml2;
|
||||||
|
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
|
||||||
float best = scale * sumlx;
|
float best = scale * sumlx;
|
||||||
for (int itry = 0; itry < 3; ++itry) {
|
for (int is = -9; is <= 9; ++is) {
|
||||||
iscale = 1/scale;
|
|
||||||
float slx = 0;
|
|
||||||
float sl2 = 0;
|
|
||||||
bool changed = false;
|
|
||||||
for (int i = 0; i < n; ++i) {
|
|
||||||
int l = nearest_int(iscale * x[i]);
|
|
||||||
l = MAX(-nmax, MIN(nmax-1, l));
|
|
||||||
if (l + nmax != L[i]) { changed = true; }
|
|
||||||
float w = weight_type == 1 ? x[i] * x[i] : 1.f;
|
|
||||||
slx += w*x[i]*l;
|
|
||||||
sl2 += w*l*l;
|
|
||||||
}
|
|
||||||
if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
|
|
||||||
for (int i = 0; i < n; ++i) {
|
|
||||||
int l = nearest_int(iscale * x[i]);
|
|
||||||
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
|
|
||||||
}
|
|
||||||
sumlx = slx; suml2 = sl2;
|
|
||||||
scale = sumlx/suml2;
|
|
||||||
best = scale * sumlx;
|
|
||||||
}
|
|
||||||
for (int itry = 0; itry < 5; ++itry) {
|
|
||||||
int n_changed = 0;
|
|
||||||
for (int i = 0; i < n; ++i) {
|
|
||||||
float w = weight_type == 1 ? x[i]*x[i] : 1;
|
|
||||||
int l = L[i] - nmax;
|
|
||||||
float slx = sumlx - w*x[i]*l;
|
|
||||||
if (slx > 0) {
|
|
||||||
float sl2 = suml2 - w*l*l;
|
|
||||||
int new_l = nearest_int(x[i] * sl2 / slx);
|
|
||||||
new_l = MAX(-nmax, MIN(nmax-1, new_l));
|
|
||||||
if (new_l != l) {
|
|
||||||
slx += w*x[i]*new_l;
|
|
||||||
sl2 += w*new_l*new_l;
|
|
||||||
if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
|
|
||||||
L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
|
|
||||||
scale = sumlx / suml2; best = scale * sumlx;
|
|
||||||
++n_changed;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!n_changed) { break; }
|
|
||||||
}
|
|
||||||
if (rmse_type < 3) {
|
|
||||||
return scale;
|
|
||||||
}
|
|
||||||
for (int is = -4; is <= 4; ++is) {
|
|
||||||
if (is == 0) {
|
if (is == 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -221,12 +179,17 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
|
||||||
return 1/iscale;
|
return 1/iscale;
|
||||||
}
|
}
|
||||||
|
|
||||||
static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
|
static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
|
||||||
|
int ntry, float alpha) {
|
||||||
float min = x[0];
|
float min = x[0];
|
||||||
float max = x[0];
|
float max = x[0];
|
||||||
|
float sum_x = 0;
|
||||||
|
float sum_x2 = 0;
|
||||||
for (int i = 1; i < n; ++i) {
|
for (int i = 1; i < n; ++i) {
|
||||||
if (x[i] < min) min = x[i];
|
if (x[i] < min) min = x[i];
|
||||||
if (x[i] > max) max = x[i];
|
if (x[i] > max) max = x[i];
|
||||||
|
sum_x += x[i];
|
||||||
|
sum_x2 += x[i]*x[i];
|
||||||
}
|
}
|
||||||
if (max == min) {
|
if (max == min) {
|
||||||
for (int i = 0; i < n; ++i) L[i] = 0;
|
for (int i = 0; i < n; ++i) L[i] = 0;
|
||||||
|
@ -254,7 +217,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
sum += x[i] - scale*L[i];
|
sum += x[i] - scale*L[i];
|
||||||
}
|
}
|
||||||
min = sum/n;
|
min = alpha*min + (1 - alpha)*sum/n;
|
||||||
if (min > 0) min = 0;
|
if (min > 0) min = 0;
|
||||||
iscale = 1/scale;
|
iscale = 1/scale;
|
||||||
if (!did_change) break;
|
if (!did_change) break;
|
||||||
|
@ -263,6 +226,82 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
||||||
return scale;
|
return scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
|
||||||
|
uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
|
||||||
|
float rmin, float rdelta, int nstep, bool use_mad) {
|
||||||
|
float min = x[0];
|
||||||
|
float max = x[0];
|
||||||
|
float sum_w = weights[0];
|
||||||
|
float sum_x = sum_w * x[0];
|
||||||
|
for (int i = 1; i < n; ++i) {
|
||||||
|
if (x[i] < min) min = x[i];
|
||||||
|
if (x[i] > max) max = x[i];
|
||||||
|
float w = weights[i];
|
||||||
|
sum_w += w;
|
||||||
|
sum_x += w * x[i];
|
||||||
|
}
|
||||||
|
if (min > 0) min = 0;
|
||||||
|
if (max == min) {
|
||||||
|
for (int i = 0; i < n; ++i) L[i] = 0;
|
||||||
|
*the_min = -min;
|
||||||
|
return 0.f;
|
||||||
|
}
|
||||||
|
float iscale = nmax/(max - min);
|
||||||
|
float scale = 1/iscale;
|
||||||
|
float best_mad = 0;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
int l = nearest_int(iscale*(x[i] - min));
|
||||||
|
L[i] = MAX(0, MIN(nmax, l));
|
||||||
|
float diff = scale * L[i] + min - x[i];
|
||||||
|
diff = use_mad ? fabsf(diff) : diff * diff;
|
||||||
|
float w = weights[i];
|
||||||
|
best_mad += w * diff;
|
||||||
|
}
|
||||||
|
if (nstep < 1) {
|
||||||
|
*the_min = -min;
|
||||||
|
return scale;
|
||||||
|
}
|
||||||
|
for (int is = 0; is <= nstep; ++is) {
|
||||||
|
iscale = (rmin + rdelta*is + nmax)/(max - min);
|
||||||
|
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
int l = nearest_int(iscale*(x[i] - min));
|
||||||
|
l = MAX(0, MIN(nmax, l));
|
||||||
|
Laux[i] = l;
|
||||||
|
float w = weights[i];
|
||||||
|
sum_l += w*l;
|
||||||
|
sum_l2 += w*l*l;
|
||||||
|
sum_xl += w*l*x[i];
|
||||||
|
}
|
||||||
|
float D = sum_w * sum_l2 - sum_l * sum_l;
|
||||||
|
if (D > 0) {
|
||||||
|
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
|
||||||
|
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
|
||||||
|
if (this_min > 0) {
|
||||||
|
this_min = 0;
|
||||||
|
this_scale = sum_xl / sum_l2;
|
||||||
|
}
|
||||||
|
float mad = 0;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
float diff = this_scale * Laux[i] + this_min - x[i];
|
||||||
|
diff = use_mad ? fabsf(diff) : diff * diff;
|
||||||
|
float w = weights[i];
|
||||||
|
mad += w * diff;
|
||||||
|
}
|
||||||
|
if (mad < best_mad) {
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
L[i] = Laux[i];
|
||||||
|
}
|
||||||
|
best_mad = mad;
|
||||||
|
scale = this_scale;
|
||||||
|
min = this_min;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*the_min = -min;
|
||||||
|
return scale;
|
||||||
|
}
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
||||||
if (j < 4) {
|
if (j < 4) {
|
||||||
|
@ -281,6 +320,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
uint8_t L[QK_K];
|
uint8_t L[QK_K];
|
||||||
|
uint8_t Laux[16];
|
||||||
|
float weights[16];
|
||||||
float mins[QK_K/16];
|
float mins[QK_K/16];
|
||||||
float scales[QK_K/16];
|
float scales[QK_K/16];
|
||||||
|
|
||||||
|
@ -291,7 +332,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
||||||
float max_scale = 0; // as we are deducting the min, scales are always positive
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/16; ++j) {
|
for (int j = 0; j < QK_K/16; ++j) {
|
||||||
scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5);
|
for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
|
||||||
|
scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
@ -637,6 +679,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
uint8_t L[QK_K];
|
uint8_t L[QK_K];
|
||||||
|
uint8_t Laux[32];
|
||||||
|
float weights[32];
|
||||||
float mins[QK_K/32];
|
float mins[QK_K/32];
|
||||||
float scales[QK_K/32];
|
float scales[QK_K/32];
|
||||||
|
|
||||||
|
@ -645,7 +689,12 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
||||||
float max_scale = 0; // as we are deducting the min, scales are always positive
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/32; ++j) {
|
for (int j = 0; j < QK_K/32; ++j) {
|
||||||
scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
|
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
||||||
|
float sum_x2 = 0;
|
||||||
|
for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
|
||||||
|
float av_x = sqrtf(sum_x2/32);
|
||||||
|
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
|
||||||
|
scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
@ -798,6 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
||||||
uint8_t L[QK_K];
|
uint8_t L[QK_K];
|
||||||
float mins[QK_K/32];
|
float mins[QK_K/32];
|
||||||
float scales[QK_K/32];
|
float scales[QK_K/32];
|
||||||
|
float weights[32];
|
||||||
|
uint8_t Laux[32];
|
||||||
#else
|
#else
|
||||||
int8_t L[QK_K];
|
int8_t L[QK_K];
|
||||||
float scales[QK_K/16];
|
float scales[QK_K/16];
|
||||||
|
@ -810,7 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
||||||
float max_scale = 0; // as we are deducting the min, scales are always positive
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/32; ++j) {
|
for (int j = 0; j < QK_K/32; ++j) {
|
||||||
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5);
|
//scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
||||||
|
float sum_x2 = 0;
|
||||||
|
for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
|
||||||
|
float av_x = sqrtf(sum_x2/32);
|
||||||
|
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
|
||||||
|
scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
|
293 llama.cpp
@ -10,13 +10,7 @@
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#if !defined(GGML_USE_CUBLAS)
|
|
||||||
#include "ggml-alloc.h"
|
#include "ggml-alloc.h"
|
||||||
# define LLAMA_USE_ALLOCATOR
|
|
||||||
#else
|
|
||||||
# define LLAMA_USE_SCRATCH
|
|
||||||
# define LLAMA_MAX_SCRATCH_BUFFERS 16
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
# include "ggml-cuda.h"
|
# include "ggml-cuda.h"
|
||||||
|
@ -588,14 +582,6 @@ struct llama_state {
|
||||||
|
|
||||||
static llama_state g_state;
|
static llama_state g_state;
|
||||||
|
|
||||||
//
|
|
||||||
// memory sizes (calculated for n_batch == 512)
|
|
||||||
//
|
|
||||||
|
|
||||||
// computed for n_ctx == 2048
|
|
||||||
// TODO: dynamically determine these sizes
|
|
||||||
// needs modifications in ggml
|
|
||||||
|
|
||||||
// available llama models
|
// available llama models
|
||||||
enum e_model {
|
enum e_model {
|
||||||
MODEL_UNKNOWN,
|
MODEL_UNKNOWN,
|
||||||
|
@ -610,76 +596,6 @@ enum e_model {
|
||||||
static const size_t kB = 1024;
|
static const size_t kB = 1024;
|
||||||
static const size_t MB = 1024*1024;
|
static const size_t MB = 1024*1024;
|
||||||
|
|
||||||
static std::map<e_model, size_t> MEM_REQ_SCRATCH0(int n_ctx)
|
|
||||||
{
|
|
||||||
std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB },
|
|
||||||
{ MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
|
|
||||||
{ MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
|
|
||||||
{ MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
|
|
||||||
{ MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
|
|
||||||
{ MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 128ull * MB },
|
|
||||||
{ MODEL_7B, 160ull * MB },
|
|
||||||
{ MODEL_13B, 192ull * MB },
|
|
||||||
{ MODEL_30B, 256ull * MB },
|
|
||||||
{ MODEL_65B, 384ull * MB }, // guess
|
|
||||||
{ MODEL_70B, 304ull * MB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// used to store the compute graph tensors + non-scratch data
|
|
||||||
static const std::map<e_model, size_t> & MEM_REQ_EVAL()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 8ull * MB },
|
|
||||||
{ MODEL_7B, 10ull * MB },
|
|
||||||
{ MODEL_13B, 12ull * MB },
|
|
||||||
{ MODEL_30B, 16ull * MB },
|
|
||||||
{ MODEL_65B, 24ull * MB }, // guess
|
|
||||||
{ MODEL_70B, 24ull * MB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// amount of VRAM needed per batch size to hold temporary results
|
|
||||||
// the values for 3b are not derived from testing but instead chosen conservatively
|
|
||||||
static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 512ull * kB },
|
|
||||||
{ MODEL_7B, 512ull * kB },
|
|
||||||
{ MODEL_13B, 640ull * kB },
|
|
||||||
{ MODEL_30B, 768ull * kB },
|
|
||||||
{ MODEL_65B, 1280ull * kB },
|
|
||||||
{ MODEL_70B, 1280ull * kB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// amount of VRAM needed per batch size and context to hold temporary results
|
|
||||||
// the values for 3b are not derived from testing but instead chosen conservatively
|
|
||||||
static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 128ull },
|
|
||||||
{ MODEL_7B, 128ull },
|
|
||||||
{ MODEL_13B, 160ull },
|
|
||||||
{ MODEL_30B, 208ull },
|
|
||||||
{ MODEL_65B, 256ull },
|
|
||||||
{ MODEL_70B, 256ull },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// default hparams (LLaMA 7B)
|
// default hparams (LLaMA 7B)
|
||||||
struct llama_hparams {
|
struct llama_hparams {
|
||||||
uint32_t n_vocab = 32000;
|
uint32_t n_vocab = 32000;
|
||||||
|
@ -787,7 +703,7 @@ struct llama_vocab {
|
||||||
// default LLaMA special tokens
|
// default LLaMA special tokens
|
||||||
id special_bos_id = 1;
|
id special_bos_id = 1;
|
||||||
id special_eos_id = 2;
|
id special_eos_id = 2;
|
||||||
id special_unk_id = -1;
|
id special_unk_id = 0;
|
||||||
id special_sep_id = -1;
|
id special_sep_id = -1;
|
||||||
id special_pad_id = -1;
|
id special_pad_id = -1;
|
||||||
|
|
||||||
|
@ -857,11 +773,9 @@ struct llama_context {
|
||||||
ggml_metal_free(ctx_metal);
|
ggml_metal_free(ctx_metal);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
if (alloc) {
|
if (alloc) {
|
||||||
ggml_allocr_free(alloc);
|
ggml_allocr_free(alloc);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
@ -901,17 +815,8 @@ struct llama_context {
|
||||||
// memory buffers used to evaluate the model
|
// memory buffers used to evaluate the model
|
||||||
llama_buffer buf_compute;
|
llama_buffer buf_compute;
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
llama_buffer buf_alloc;
|
llama_buffer buf_alloc;
|
||||||
ggml_allocr * alloc = NULL;
|
ggml_allocr * alloc = NULL;
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef LLAMA_USE_SCRATCH
|
|
||||||
llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
|
|
||||||
|
|
||||||
int buf_last = 0;
|
|
||||||
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
ggml_metal_context * ctx_metal = NULL;
|
ggml_metal_context * ctx_metal = NULL;
|
||||||
|
@ -920,37 +825,6 @@ struct llama_context {
|
||||||
#ifdef GGML_USE_MPI
|
#ifdef GGML_USE_MPI
|
||||||
ggml_mpi_context * ctx_mpi = NULL;
|
ggml_mpi_context * ctx_mpi = NULL;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void use_buf(struct ggml_context * ctx, int i) { // NOLINT
|
|
||||||
#if defined(LLAMA_USE_SCRATCH)
|
|
||||||
size_t last_size = 0;
|
|
||||||
|
|
||||||
if (i == -1) {
|
|
||||||
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
|
||||||
} else {
|
|
||||||
auto & buf = buf_scratch[i];
|
|
||||||
last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buf_last >= 0) {
|
|
||||||
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
buf_last = i;
|
|
||||||
#else
|
|
||||||
(void) i;
|
|
||||||
(void) ctx;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t get_buf_max_mem(int i) { // NOLINT
|
|
||||||
#if defined(LLAMA_USE_SCRATCH)
|
|
||||||
return buf_max_size[i];
|
|
||||||
#else
|
|
||||||
(void) i;
|
|
||||||
return 0;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -1121,6 +995,16 @@ struct llama_model_loader {
|
||||||
} break;
|
} break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this is a way to mark that we have "guessed" the file type
|
||||||
|
ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
|
||||||
|
|
||||||
|
{
|
||||||
|
const int kid = gguf_find_key(ctx_gguf, "general.file_type");
|
||||||
|
if (kid >= 0) {
|
||||||
|
ftype = (llama_ftype) gguf_get_val_u32(ctx_gguf, kid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; i++) {
|
for (int i = 0; i < n_kv; i++) {
|
||||||
const char * name = gguf_get_key(ctx_gguf, i);
|
const char * name = gguf_get_key(ctx_gguf, i);
|
||||||
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
|
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
|
||||||
|
@ -1323,7 +1207,11 @@ struct llama_model_loader {
|
||||||
// load LLaMA models
|
// load LLaMA models
|
||||||
//
|
//
|
||||||
|
|
||||||
const char * llama_model_ftype_name(enum llama_ftype ftype) {
|
std::string llama_model_ftype_name(enum llama_ftype ftype) {
|
||||||
|
if (ftype & LLAMA_FTYPE_GUESSED) {
|
||||||
|
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
|
||||||
|
}
|
||||||
|
|
||||||
switch (ftype) {
|
switch (ftype) {
|
||||||
case LLAMA_FTYPE_ALL_F32: return "all F32";
|
case LLAMA_FTYPE_ALL_F32: return "all F32";
|
||||||
case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16";
|
case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16";
|
||||||
|
@ -1552,7 +1440,7 @@ static void llama_model_load_internal(
|
||||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
|
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
|
||||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
|
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
|
||||||
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
|
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
|
||||||
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype));
|
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
|
||||||
LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml->n_elements*1e-9);
|
LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml->n_elements*1e-9);
|
||||||
|
|
||||||
// general kv
|
// general kv
|
||||||
|
@ -1620,7 +1508,6 @@ static void llama_model_load_internal(
|
||||||
|
|
||||||
// prepare memory for the weights
|
// prepare memory for the weights
|
||||||
size_t vram_weights = 0;
|
size_t vram_weights = 0;
|
||||||
size_t vram_scratch = 0;
|
|
||||||
{
|
{
|
||||||
const uint32_t n_embd = hparams.n_embd;
|
const uint32_t n_embd = hparams.n_embd;
|
||||||
const uint32_t n_embd_gqa = hparams.n_embd_gqa();
|
const uint32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||||
|
@ -1701,13 +1588,6 @@ static void llama_model_load_internal(
|
||||||
ctx_size +
|
ctx_size +
|
||||||
mmapped_size - vram_weights; // weights in VRAM not in memory
|
mmapped_size - vram_weights; // weights in VRAM not in memory
|
||||||
|
|
||||||
#ifndef LLAMA_USE_ALLOCATOR
|
|
||||||
mem_required +=
|
|
||||||
MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
|
|
||||||
MEM_REQ_SCRATCH1().at(model.type) +
|
|
||||||
MEM_REQ_EVAL().at(model.type);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// this is the memory required by one llama_state
|
// this is the memory required by one llama_state
|
||||||
const size_t mem_required_state =
|
const size_t mem_required_state =
|
||||||
scale*hparams.kv_size();
|
scale*hparams.kv_size();
|
||||||
|
@ -1715,24 +1595,7 @@ static void llama_model_load_internal(
|
||||||
LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
|
LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
|
||||||
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
|
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
|
||||||
|
|
||||||
(void) vram_scratch;
|
|
||||||
(void) n_batch;
|
(void) n_batch;
|
||||||
#ifdef GGML_USE_CUBLAS
|
|
||||||
if (low_vram) {
|
|
||||||
LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
|
|
||||||
ggml_cuda_set_scratch_size(0); // disable scratch
|
|
||||||
} else {
|
|
||||||
const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
|
|
||||||
const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
|
|
||||||
vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
|
|
||||||
ggml_cuda_set_scratch_size(vram_scratch);
|
|
||||||
if (n_gpu_layers > 0) {
|
|
||||||
LLAMA_LOG_INFO("%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
|
|
||||||
__func__, vram_scratch_base / kB, vram_scratch_per_context,
|
|
||||||
(vram_scratch + MB - 1) / MB); // round up
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif // GGML_USE_CUBLAS
|
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
||||||
|
@ -1769,8 +1632,8 @@ static void llama_model_load_internal(
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
|
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
|
||||||
__func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
|
__func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
|
||||||
LLAMA_LOG_INFO("%s: total VRAM used: %zu MB\n",
|
LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
|
||||||
__func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
|
__func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
|
||||||
#else
|
#else
|
||||||
(void) n_gpu_layers;
|
(void) n_gpu_layers;
|
||||||
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
|
@ -1875,9 +1738,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
/*.no_alloc =*/ false,
|
/*.no_alloc =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
params.no_alloc = true;
|
params.no_alloc = true;
|
||||||
#endif
|
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
|
|
||||||
|
@ -1889,14 +1750,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
if (tokens) {
|
if (tokens) {
|
||||||
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
|
||||||
#endif
|
|
||||||
ggml_set_name(inp_tokens, "inp_tokens");
|
ggml_set_name(inp_tokens, "inp_tokens");
|
||||||
|
|
||||||
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
||||||
|
@ -1907,14 +1764,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
|
|
||||||
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
|
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
ggml_allocr_alloc(lctx.alloc, inpL);
|
ggml_allocr_alloc(lctx.alloc, inpL);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||||
|
@ -1931,25 +1784,21 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
if (n_gpu_layers > n_layer) {
|
if (n_gpu_layers > n_layer) {
|
||||||
offload_func_nr = ggml_cuda_assign_buffers;
|
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
|
||||||
}
|
}
|
||||||
if (n_gpu_layers > n_layer + 1) {
|
if (n_gpu_layers > n_layer + 1) {
|
||||||
offload_func_v = ggml_cuda_assign_buffers;
|
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
|
||||||
}
|
}
|
||||||
if (n_gpu_layers > n_layer + 2) {
|
if (n_gpu_layers > n_layer + 2) {
|
||||||
offload_func_kq = ggml_cuda_assign_buffers;
|
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
|
||||||
}
|
}
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
|
|
||||||
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
|
||||||
#endif
|
|
||||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
@ -1959,14 +1808,12 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
if (il >= i_gpu_start) {
|
if (il >= i_gpu_start) {
|
||||||
offload_func = ggml_cuda_assign_buffers;
|
offload_func = ggml_cuda_assign_buffers_no_alloc;
|
||||||
}
|
}
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
|
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
lctx.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
|
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
|
||||||
|
@ -2104,8 +1951,6 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
ggml_set_name(cur, "result_wo");
|
ggml_set_name(cur, "result_wo");
|
||||||
}
|
}
|
||||||
|
|
||||||
lctx.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
|
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
|
||||||
offload_func(inpFF);
|
offload_func(inpFF);
|
||||||
ggml_set_name(inpFF, "inpFF");
|
ggml_set_name(inpFF, "inpFF");
|
||||||
|
@ -2160,8 +2005,6 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
inpL = cur;
|
inpL = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
lctx.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
|
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
|
||||||
|
@ -2178,8 +2021,6 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
ggml_set_name(cur, "result_output");
|
ggml_set_name(cur, "result_output");
|
||||||
|
|
||||||
lctx.use_buf(ctx0, -1);
|
|
||||||
|
|
||||||
// logits -> probs
|
// logits -> probs
|
||||||
//cur = ggml_soft_max_inplace(ctx0, cur);
|
//cur = ggml_soft_max_inplace(ctx0, cur);
|
||||||
|
|
||||||
|
@ -2189,15 +2030,6 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if 0
|
|
||||||
LLAMA_LOG_INFO("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
|
|
||||||
ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
||||||
lctx.get_buf_max_mem(0)/1024.0/1024.0,
|
|
||||||
lctx.get_buf_max_mem(1)/1024.0/1024.0,
|
|
||||||
lctx.work_buffer.size()/1024.0/1024.0,
|
|
||||||
n_past, N);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
ggml_free(ctx0);
|
ggml_free(ctx0);
|
||||||
|
|
||||||
return gf;
|
return gf;
|
||||||
|
@ -2248,14 +2080,26 @@ static bool llama_eval_internal(
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
const int64_t n_vocab = hparams.n_vocab;
|
const int64_t n_vocab = hparams.n_vocab;
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
ggml_allocr_reset(lctx.alloc);
|
ggml_allocr_reset(lctx.alloc);
|
||||||
#endif
|
|
||||||
|
|
||||||
ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
|
ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
ggml_allocr_alloc_graph(lctx.alloc, gf);
|
ggml_allocr_alloc_graph(lctx.alloc, gf);
|
||||||
|
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
for (int i = 0; i < gf->n_leafs; i++) {
|
||||||
|
ggml_tensor * node = gf->leafs[i];
|
||||||
|
if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
|
||||||
|
ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < gf->n_nodes; i++) {
|
||||||
|
ggml_tensor * node = gf->nodes[i];
|
||||||
|
if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
|
||||||
|
ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||||
|
@ -2409,18 +2253,11 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string llama_escape_whitespace(const std::string& text) {
|
static std::string llama_escape_whitespace(const std::string& text) {
|
||||||
std::string result;
|
std::string result = "\xe2\x96\x81";
|
||||||
bool escaping = false;
|
|
||||||
result += "\xe2\x96\x81";
|
|
||||||
for (size_t offs = 0; offs < text.length(); ++offs) {
|
for (size_t offs = 0; offs < text.length(); ++offs) {
|
||||||
if (text[offs] == ' ') {
|
if (text[offs] == ' ') {
|
||||||
if (!escaping) {
|
|
||||||
result += "\xe2\x96\x81";
|
result += "\xe2\x96\x81";
|
||||||
escaping = true;
|
} else {
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
escaping = false;
|
|
||||||
result += text[offs];
|
result += text[offs];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3620,6 +3457,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
// copy the KV pairs from the input file
|
// copy the KV pairs from the input file
|
||||||
gguf_set_kv (ctx_out, model_loader->ctx_gguf);
|
gguf_set_kv (ctx_out, model_loader->ctx_gguf);
|
||||||
gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
|
gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
|
||||||
|
gguf_set_val_u32(ctx_out, "general.file_type", ftype);
|
||||||
|
|
||||||
#ifdef GGML_USE_K_QUANTS
|
#ifdef GGML_USE_K_QUANTS
|
||||||
int n_attention_wv = 0;
|
int n_attention_wv = 0;
|
||||||
|
@ -3717,24 +3555,40 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
new_type = GGML_TYPE_Q6_K;
|
new_type = GGML_TYPE_Q6_K;
|
||||||
}
|
}
|
||||||
} else if (name.find("attn_v.weight") != std::string::npos) {
|
} else if (name.find("attn_v.weight") != std::string::npos) {
|
||||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
|
||||||
|
new_type = i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
||||||
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
||||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
||||||
use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
|
use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
|
||||||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
|
||||||
else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
|
else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
|
||||||
(i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
|
(i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
|
||||||
++i_attention_wv;
|
++i_attention_wv;
|
||||||
} else if (name.find("ffn_down.weight") != std::string::npos) {
|
} else if (name.find("ffn_down.weight") != std::string::npos) {
|
||||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
|
||||||
|
new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
||||||
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
||||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
||||||
use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
|
use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
|
||||||
//else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
|
||||||
++i_feed_forward_w2;
|
++i_feed_forward_w2;
|
||||||
} else if (name.find("attn_output.weight") != std::string::npos) {
|
} else if (name.find("attn_output.weight") != std::string::npos) {
|
||||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
|
||||||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K;
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
||||||
}
|
}
|
||||||
|
else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) {
|
||||||
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||||
|
}
|
||||||
|
// This can be used to reduce the size of the Q5_K_S model.
|
||||||
|
// The associated PPL increase is fully in line with the size reduction
|
||||||
|
//else {
|
||||||
|
// if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
|
||||||
|
//}
|
||||||
bool convert_incompatible_tensor = false;
|
bool convert_incompatible_tensor = false;
|
||||||
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
|
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
|
||||||
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {
|
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {
|
||||||
|
@ -4319,7 +4173,6 @@ struct llama_context * llama_new_context_with_model(
|
||||||
ctx->embedding.resize(hparams.n_embd);
|
ctx->embedding.resize(hparams.n_embd);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
{
|
{
|
||||||
static const size_t tensor_alignment = 32;
|
static const size_t tensor_alignment = 32;
|
||||||
// the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
|
// the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
|
||||||
|
@ -4350,13 +4203,6 @@ struct llama_context * llama_new_context_with_model(
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
|
LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
|
||||||
|
|
||||||
// debug - for comparison with scratch buffer
|
|
||||||
//size_t prev_req =
|
|
||||||
// MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
|
|
||||||
// MEM_REQ_SCRATCH1().at(ctx->model.type) +
|
|
||||||
// MEM_REQ_EVAL().at(ctx->model.type);
|
|
||||||
//LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
|
|
||||||
|
|
||||||
// recreate allocator with exact memory requirements
|
// recreate allocator with exact memory requirements
|
||||||
ggml_allocr_free(ctx->alloc);
|
ggml_allocr_free(ctx->alloc);
|
||||||
|
|
||||||
|
@@ -4367,16 +4213,17 @@ struct llama_context * llama_new_context_with_model(
                 ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
             }
 #endif
+#ifdef GGML_USE_CUBLAS
+            if (params.low_vram) {
+                LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+                ggml_cuda_set_scratch_size(0); // disable scratch
+            } else {
+                ggml_cuda_set_scratch_size(alloc_size);
+                LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
             }
-#else
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
-        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
 #endif
         }
+    }

 #ifdef GGML_USE_METAL
     if (params.n_gpu_layers > 0) {
@@ -4471,7 +4318,7 @@ int llama_model_n_embd(const struct llama_model * model) {
 }

 int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype));
+    return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str());
 }

 int llama_model_quantize(
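A side note on the `.c_str()` change above, presumably needed because `llama_model_ftype_name()` now returns a `std::string`: passing a `std::string` through a C varargs `%s` is undefined behaviour, so the value has to be converted to `const char *` first. A minimal, self-contained illustration (the helper below is a hypothetical stand-in, not the real API):

```cpp
#include <cstdio>
#include <string>

// Hypothetical stand-in for a std::string-returning name helper.
static std::string ftype_name() { return "mostly Q4_0"; }

int main() {
    char buf[64];
    // snprintf(buf, sizeof(buf), "LLaMA 7B %s", ftype_name());      // wrong: std::string through varargs is UB
    snprintf(buf, sizeof(buf), "LLaMA 7B %s", ftype_name().c_str()); // right: pass a const char *
    std::puts(buf);
    return 0;
}
```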
2
llama.h

@@ -103,6 +103,8 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q6_K   = 18, // except 1d tensors
+
+        LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };

     typedef struct llama_token_data {
@@ -1,6 +1,7 @@
 #!/bin/bash

 cp -rpv ../ggml/src/ggml.c        ./ggml.c
+cp -rpv ../ggml/src/ggml-alloc.c  ./ggml-alloc.c
 cp -rpv ../ggml/src/ggml-cuda.h   ./ggml-cuda.h
 cp -rpv ../ggml/src/ggml-cuda.cu  ./ggml-cuda.cu
 cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
@@ -9,6 +10,7 @@ cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
 cp -rpv ../ggml/src/ggml-metal.m     ./ggml-metal.m
 cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
 cp -rpv ../ggml/include/ggml/ggml.h  ./ggml.h
+cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h

 cp -rpv ../ggml/tests/test-opt.cpp   ./tests/test-opt.cpp
 cp -rpv ../ggml/tests/test-grad0.cpp ./tests/test-grad0.cpp
@@ -17,6 +17,8 @@ static std::string unescape_whitespace(llama_context* ctx, const std::vector<lla
 static const std::map<std::string, std::vector<llama_token>> & k_tests() {
     static std::map<std::string, std::vector<llama_token>> _k_tests = {
         { " ",                    { 1, 259, }, },
+        { "  ",                   { 1, 1678, }, },
+        { "   ",                  { 1, 268, }, },
         { "\t",                   { 1, 29871, 12, }, },
         { "\n",                   { 1, 29871, 13, }, },
         { "\t\n",                 { 1, 29871, 12, 13, }, },
@@ -38,6 +40,12 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
           243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
           313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
           313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
+        { "Hello",                { 1, 15043 }, },
+        { " Hello",               { 1, 29871, 15043 }, },
+        { "  Hello",              { 1, 259, 15043 }, },
+        { "   Hello",             { 1, 1678, 15043 }, },
+        { "    Hello",            { 1, 268, 15043 }, },
+        { "    Hello\n    Hello", { 1, 268, 15043, 13, 1678, 15043 }, },
     };

     return _k_tests;
@@ -106,7 +114,8 @@ int main(int argc, char **argv) {

         if (!correct) {
             fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
-            fprintf(stderr, "%s : detokenized to: '%s'\n", __func__, unescape_whitespace(ctx, test_kv.second).c_str());
+            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                unescape_whitespace(ctx, res).c_str(), unescape_whitespace(ctx, test_kv.second).c_str());
             fprintf(stderr, "%s : expected tokens: ", __func__);
             for (const auto & t : test_kv.second) {
                 fprintf(stderr, "%6d, ", t);
@@ -11,18 +11,11 @@
 #include <locale>

 static std::string escape_whitespace(const std::string& text) {
-    std::string result;
-    bool escaping = false;
-    result += "\xe2\x96\x81";
+    std::string result = "\xe2\x96\x81";
     for (size_t offs = 0; offs < text.length(); ++offs) {
         if (text[offs] == ' ') {
-            if (!escaping) {
-                result += "\xe2\x96\x81";
-                escaping = true;
-            }
-        }
-        else {
-            escaping = false;
+            result += "\xe2\x96\x81";
+        } else {
             result += text[offs];
         }
     }
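To make the behaviour change above concrete, here is a self-contained sketch (not part of the commit) built around the new `escape_whitespace()`: it prepends one U+2581 marker ("\xe2\x96\x81", rendered as ▁) and maps every space to a marker, whereas the removed version collapsed runs of consecutive spaces into a single marker.

```cpp
#include <cstdio>
#include <string>

static std::string escape_whitespace(const std::string & text) {
    std::string result = "\xe2\x96\x81";      // leading marker, as in the new version
    for (size_t offs = 0; offs < text.length(); ++offs) {
        if (text[offs] == ' ') {
            result += "\xe2\x96\x81";         // one marker per space (no collapsing)
        } else {
            result += text[offs];
        }
    }
    return result;
}

int main() {
    std::printf("%s\n", escape_whitespace("Hello world").c_str()); // ▁Hello▁world
    std::printf("%s\n", escape_whitespace("  Hello").c_str());     // ▁▁▁Hello
    return 0;
}
```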