From 95078cc554fe03d4512363c7e4dec963f0047c72 Mon Sep 17 00:00:00 2001 From: ubik2 Date: Mon, 8 May 2023 04:54:26 -0700 Subject: [PATCH 1/6] convert: add ability to convert safetensors files (#1276) * when loading a safetensors file, ignore the metadata header * check for safetensors files first, and only use PyTorch versions when safetensors aren't available --- convert.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/convert.py b/convert.py index 126beaabc..8f4f0399e 100644 --- a/convert.py +++ b/convert.py @@ -766,7 +766,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) description = f'safetensors begin={begin} end={end} type={data_type} path={path}' return LazyTensor(load, shape, data_type, description) - model = {name: convert(info) for (name, info) in header.items()} + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) @@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus: '''Load a model of any supported format.''' # Be extra-friendly and accept either a file or a directory: if path.is_dir(): - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] - files = [file for glob in globs for file in path.glob(glob)] + # Check if it's a set of safetensors files first + files = list(path.glob("model-00001-of-*.safetensors")) + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] + files = [file for glob in globs for file in path.glob(glob)] if not files: # Try GGML too, but with lower priority, since if both a non-GGML # model and a GGML model exist in the same directory, we assume the From f9a6364912fd0463fddfdbc9ef9f79fdc281570d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 May 2023 17:41:54 +0300 Subject: [PATCH 2/6] llama : require first token to be BOS (#1303) * llama : require first token to be BOS * scripts : add ppl-run-all.sh * perplexity : add BOS for each chunk * readme : update perplexity values after BOS fix * perplexity : add clarifying comments --- .gitignore | 1 + README.md | 32 +++++--------- examples/common.cpp | 4 +- examples/main/main.cpp | 4 +- examples/perplexity/perplexity.cpp | 70 ++++++++++++++++++++---------- llama.cpp | 12 ++++- scripts/ppl-run-all.sh | 43 ++++++++++++++++++ 7 files changed, 116 insertions(+), 50 deletions(-) create mode 100755 scripts/ppl-run-all.sh diff --git a/.gitignore b/.gitignore index 6f275fea4..a5fef3277 100644 --- a/.gitignore +++ b/.gitignore @@ -43,5 +43,6 @@ zig-out/ zig-cache/ ppl-*.txt +qnt-*.txt examples/jeopardy/results.txt diff --git a/README.md b/README.md index 6cbdcbf83..438748a91 100644 --- a/README.md +++ b/README.md @@ -298,17 +298,25 @@ Several quantization methods are supported. They differ in the resulting model d | Model | Measure | F16 | Q4_0 | Q4_1 | Q4_2 | Q5_0 | Q5_1 | Q8_0 | |------:|--------------|-------:|-------:|-------:|-------:|-------:|-------:|-------:| -| 7B | perplexity | 5.9565 | 6.2103 | 6.1286 | 6.1698 | 6.0139 | 5.9934 | 5.9571 | +| 7B | perplexity | 5.9066 | 6.1620 | 6.0910 | 6.1466 | 5.9862 | 5.9481 | 5.9069 | | 7B | file size | 13.0G | 4.0G | 4.8G | 4.0G | 4.4G | 4.8G | 7.1G | | 7B | ms/tok @ 4th | 128 | 56 | 61 | 84 | 91 | 95 | 75 | | 7B | ms/tok @ 8th | 128 | 47 | 55 | 48 | 53 | 59 | 75 | | 7B | bits/weight | 16.0 | 5.0 | 6.0 | 5.0 | 5.5 | 6.0 | 9.0 | -| 13B | perplexity | 5.2455 | 5.3748 | 5.3471 | 5.3433 | 5.2768 | 5.2582 | 5.2458 | +| 13B | perplexity | 5.2543 | 5.3863 | 5.3607 | 5.3513 | 5.2856 | 5.2706 | 5.2548 | | 13B | file size | 25.0G | 7.6G | 9.1G | 7.6G | 8.4G | 9.1G | 14G | | 13B | ms/tok @ 4th | 239 | 104 | 113 | 160 | 176 | 185 | 141 | | 13B | ms/tok @ 8th | 240 | 85 | 99 | 97 | 108 | 117 | 147 | | 13B | bits/weight | 16.0 | 5.0 | 6.0 | 5.0 | 5.5 | 6.0 | 9.0 | +### Perplexity (measuring model quality) + +You can use the `perplexity` example to measure perplexity over a given prompt (lower perplexity is better). +For more information, see [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity). + +The perplexity measurements in table above are done against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with context length of 512. +The time per token is measured on a MacBook M1 Pro 32GB RAM using 4 and 8 threads. + ### Interactive mode If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter. @@ -407,26 +415,6 @@ If your issue is with model generation quality, then please at least scan the fo - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) -### Perplexity (measuring model quality) - -You can use the `perplexity` example to measure perplexity over the given prompt. For more background, see [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity). However, in general, lower perplexity is better for LLMs. - -#### Latest measurements - -The latest perplexity scores for the various model sizes and quantizations are being tracked in [discussion #406](https://github.com/ggerganov/llama.cpp/discussions/406). `llama.cpp` is measuring very well compared to the baseline implementations. Quantization has a small negative impact on quality, but, as you can see, running -13B at q4_0 beats the 7B f16 model by a significant amount. - -All measurements are done against the wikitext2 test dataset (https://paperswithcode.com/dataset/wikitext-2), with default options (512 length context). -Note that changing the context length will have a significant impact on perplexity (longer context = better perplexity). -``` -Perplexity - model options -5.5985 - 13B, q4_0 -5.9565 - 7B, f16 -6.3001 - 7B, q4_1 -6.5949 - 7B, q4_0 -6.5995 - 7B, q4_0, --memory_f16 -``` - #### How to run 1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research diff --git a/examples/common.cpp b/examples/common.cpp index f1c3bae13..6af440272 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -438,8 +438,8 @@ std::string gpt_random_prompt(std::mt19937 & rng) { // TODO: not great allocating this every time std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars - std::vector res(text.size() + (int)add_bos); - int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); + std::vector res(text.size() + (int) add_bos); + const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); assert(n >= 0); res.resize(n); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 5ac151e14..045093c72 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -313,7 +313,8 @@ int main(int argc, char ** argv) { if (n_past + (int) embd.size() > n_ctx) { const int n_left = n_past - params.n_keep; - n_past = params.n_keep; + // always keep the first token - BOS + n_past = std::max(1, params.n_keep); // insert n_left/2 tokens at the start of embd from last_n_tokens embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); @@ -331,7 +332,6 @@ int main(int argc, char ** argv) { } // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) - // REVIEW if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 299a19999..9212dee5c 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -25,46 +25,68 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` + // BOS tokens will be added for each chunk before eval auto tokens = ::llama_tokenize(ctx, params.prompt, true); - int count = 0; - int seq_count = tokens.size() / params.n_ctx; - int n_vocab = llama_n_vocab(ctx); + int count = 0; + + const int n_chunk = tokens.size() / params.n_ctx; + const int n_vocab = llama_n_vocab(ctx); + const int n_batch = params.n_batch; double nll = 0.0; - fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch); + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); - for (int i = 0; i < seq_count; ++i) { - int start = i * params.n_ctx; - int end = start + params.n_ctx; + for (int i = 0; i < n_chunk; ++i) { + const int start = i * params.n_ctx; + const int end = start + params.n_ctx; + + const int num_batches = (params.n_ctx + n_batch - 1) / n_batch; std::vector logits; - int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch; - auto start_t = std::chrono::high_resolution_clock::now(); + + const auto t_start = std::chrono::high_resolution_clock::now(); + for (int j = 0; j < num_batches; ++j) { - int batch_start = start + j * params.n_batch; - int batch_size = std::min(end - batch_start, params.n_batch); - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) { + const int batch_start = start + j * n_batch; + const int batch_size = std::min(end - batch_start, n_batch); + + // save original token and restore it after eval + const auto token_org = tokens[batch_start]; + + // add BOS token for the first batch of each chunk + if (j == 0) { + tokens[batch_start] = llama_token_bos(); + } + + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } - auto batch_logits = llama_get_logits(ctx); + + // restore the original token in case it was set to BOS + tokens[batch_start] = token_org; + + const auto batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } - auto end_t = std::chrono::high_resolution_clock::now(); + + const auto t_end = std::chrono::high_resolution_clock::now(); + if (i == 0) { - const float seconds = std::chrono::duration(end_t - start_t).count(); - printf("%.2f seconds per pass - ETA ", seconds); - int total_seconds = (int)(seconds * seq_count); + const float t_total = std::chrono::duration(t_end - t_start).count(); + fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk); if (total_seconds >= 60*60) { - printf("%d hours ", total_seconds / (60*60)); + fprintf(stderr, "%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } - printf("%d minutes\n", total_seconds / 60); + fprintf(stderr, "%d minutes\n", total_seconds / 60); } + // We get the logits for all the tokens in the context window (params.n_ctx) // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, - // calculate the perplexity over the last half the window (so the model always has + // calculate the perplexity over the last half of the window (so the model always has // some context to predict the token). // // We rely on the fact that attention in the forward pass only looks at previous @@ -76,10 +98,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // process the entire prompt. for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) { // Calculate probability of next token, given the previous ones. - std::vector tok_logits( - logits.begin() + j * n_vocab, + const std::vector tok_logits( + logits.begin() + (j + 0) * n_vocab, logits.begin() + (j + 1) * n_vocab); - float prob = softmax(tok_logits)[tokens[start + j + 1]]; + + const float prob = softmax(tok_logits)[tokens[start + j + 1]]; + nll += -std::log(prob); ++count; } diff --git a/llama.cpp b/llama.cpp index c36c6ced6..d54fa502c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1052,6 +1052,13 @@ static bool llama_eval_internal( const int n_tokens, const int n_past, const int n_threads) { + + // enforce that the first token is BOS + if (n_past == 0 && tokens[0] != llama_token_bos()) { + fprintf(stderr, "%s: first token must be BOS\n", __func__); + return false; + } + const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1482,7 +1489,7 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co } if (bos) { - output.push_back(1); + output.push_back(llama_token_bos()); } tokenizer.tokenize(text, output); @@ -2727,11 +2734,14 @@ int llama_eval( fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } + // get a more accurate load time, upon first eval + // TODO: fix this if (!ctx->has_evaluated_once) { ctx->t_load_us = ggml_time_us() - ctx->t_start_us; ctx->has_evaluated_once = true; } + return 0; } diff --git a/scripts/ppl-run-all.sh b/scripts/ppl-run-all.sh new file mode 100755 index 000000000..28f31ca71 --- /dev/null +++ b/scripts/ppl-run-all.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# +# quantize +# + +# 7B +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_0.bin q4_0 2>&1 | tee ../qnt-7b-q4_0.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_1.bin q4_1 2>&1 | tee ../qnt-7b-q4_1.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_2.bin q4_2 2>&1 | tee ../qnt-7b-q4_2.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q5_0.bin q5_0 2>&1 | tee ../qnt-7b-q5_0.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q5_1.bin q5_1 2>&1 | tee ../qnt-7b-q5_1.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q8_0.bin q8_0 2>&1 | tee ../qnt-7b-q8_0.txt + +# 13B +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_0.bin q4_0 2>&1 | tee ../qnt-13b-q4_0.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_1.bin q4_1 2>&1 | tee ../qnt-13b-q4_1.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_2.bin q4_2 2>&1 | tee ../qnt-13b-q4_2.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q5_0.bin q5_0 2>&1 | tee ../qnt-13b-q5_0.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q5_1.bin q5_1 2>&1 | tee ../qnt-13b-q5_1.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q8_0.bin q8_0 2>&1 | tee ../qnt-13b-q8_0.txt + +# +# perplexity +# + +# 7B +time ./bin/perplexity -m ../models/7B/ggml-model-f16.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-f16.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_0.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_1.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_2.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_2.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q5_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q5_0.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q5_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q5_1.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q8_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q8_0.txt + +# 13B +time ./bin/perplexity -m ../models/13B/ggml-model-f16.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-f16.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_0.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_1.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_2.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_2.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q5_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q5_0.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q5_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q5_1.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q8_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q8_0.txt From 003ba2fb4309e2339487564bd249e4fcc8d7ea01 Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Mon, 8 May 2023 16:48:21 +0200 Subject: [PATCH 3/6] llama : fix hparams shadow (#1367) fixes #1363 --- llama.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index d54fa502c..4bba93a11 100644 --- a/llama.cpp +++ b/llama.cpp @@ -970,8 +970,6 @@ static void llama_model_load_internal( // prepare memory for the weights { - const auto & hparams = model.hparams; - const uint32_t n_embd = hparams.n_embd; const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; From fe60904eef4b504685fa0406cb19864ae619fb4f Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 8 May 2023 21:03:30 +0430 Subject: [PATCH 4/6] readme : add TOC and Pygmalion instructions (#1359) --- README.md | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/README.md b/README.md index 438748a91..f029f06a8 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,39 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ - [Roadmap May 2023](https://github.com/ggerganov/llama.cpp/discussions/1220) - [New quantization methods](https://github.com/ggerganov/llama.cpp#quantization) +
+ Table of Contents +
    +
  1. + Description +
  2. +
  3. + Usage + +
  4. +
  5. Contributing
  6. +
  7. Coding guidelines
  8. +
  9. Docs
  10. +
+
+ ## Description The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quantization on a MacBook @@ -46,6 +79,7 @@ as the main playground for developing new features for the [ggml](https://github - [X] [Vicuna](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5533894) - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) - [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy) +- [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b) **Bindings:** @@ -383,6 +417,19 @@ python3 convert.py models/gpt4all-7B/gpt4all-lora-quantized.bin - The newer GPT4All-J model is not yet supported! +### Using Pygmalion 7B & Metharme 7B + +- Obtain the [LLaMA weights](#obtaining-the-facebook-llama-original-model-and-stanford-alpaca-model-data) +- Obtain the [Pygmalion 7B](https://huggingface.co/PygmalionAI/pygmalion-7b/) or [Metharme 7B](https://huggingface.co/PygmalionAI/metharme-7b) XOR encoded weights +- Convert the LLaMA model with [the latest HF convert script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) +- Merge the XOR files with the converted LLaMA weights by running the [xor_codec](https://huggingface.co/PygmalionAI/pygmalion-7b/blob/main/xor_codec.py) script +- Convert to `ggml` format using the `convert.py` script in this repo: +```bash +python3 convert.py pygmalion-7b/ --outtype q4_1 +``` +> The Pygmalion 7B & Metharme 7B weights are saved in [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) precision. If you wish to convert to `ggml` without quantizating, please specify the `--outtype` as `f32` instead of `f16`. + + ### Obtaining the Facebook LLaMA original model and Stanford Alpaca model data - **Under no circumstances should IPFS, magnet links, or any other links to model downloads be shared anywhere in this repository, including in issues, discussions, or pull requests. They will be immediately deleted.** From 56551bc11f46b2716fdf61bb48ac28414889dc0a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 May 2023 22:52:18 +0300 Subject: [PATCH 5/6] readme : add notice about upcoming breaking change --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index f029f06a8..045f99534 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,14 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ +## ⚠️ TEMPORARY NOTICE ABOUT UPCOMING BREAKING CHANGE ⚠️ + +**The quantization formats will soon be updated: https://github.com/ggerganov/llama.cpp/pull/1305** + +**All `ggml` model files using the old format will not work with the latest `llama.cpp` code after that change is merged** + +--- + **Hot topics:** - [Roadmap May 2023](https://github.com/ggerganov/llama.cpp/discussions/1220) From 41654efea879bbdf4fd794e13335929d4cf0eb90 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Mon, 8 May 2023 19:45:48 -0700 Subject: [PATCH 6/6] Interface improvements and `--multiline-input` (previously `--author-mode`) (#1040) * Interface improvements * Multiline input * Track character width * Works with all characters and control codes + Windows console fixes --- examples/common.cpp | 384 +++++++++++++++++++++++++++++++++++------ examples/common.h | 25 ++- examples/main/main.cpp | 60 +++---- 3 files changed, 374 insertions(+), 95 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 6af440272..23d69e7d5 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -14,20 +14,16 @@ #include #endif -#if defined (_WIN32) +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include #include #include -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int CodePage, unsigned long dwFlags, - const wchar_t * lpWideCharStr, int cchWideChar, - char * lpMultiByteStr, int cbMultiByte, - const char * lpDefaultChar, bool * lpUsedDefaultChar); -#define CP_UTF8 65001 +#else +#include +#include +#include #endif int32_t get_num_physical_cores() { @@ -269,6 +265,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.interactive_first = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "--multiline-input") { + params.multiline_input = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -359,6 +357,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -479,54 +478,339 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { return lctx; } -/* Keep track of current color of output, and emit ANSI code if it changes. */ -void set_console_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - switch(color) { - case CONSOLE_COLOR_DEFAULT: - printf(ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - printf(ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - break; - } - con_st.color = color; - } -} - -#if defined (_WIN32) -void win32_console_init(bool enable_color) { - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; +void console_init(console_state & con_st) { +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { + con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { + con_st.hConsole = NULL; } } - if (hConOut) { + if (con_st.hConsole) { // Enable ANSI colors on Windows 10+ - if (enable_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); } // Set console output codepage to UTF8 SetConsoleOutputCP(CP_UTF8); } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); + + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + SetConsoleMode(hConIn, dwMode); + } +#else + // POSIX-specific console initialization + struct termios new_termios; + tcgetattr(STDIN_FILENO, &con_st.prev_state); + new_termios = con_st.prev_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + con_st.tty = fopen("/dev/tty", "w+"); + if (con_st.tty != nullptr) { + con_st.out = con_st.tty; + } +#endif + setlocale(LC_ALL, ""); +} + +void console_cleanup(console_state & con_st) { + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + +#if !defined(_WIN32) + if (con_st.tty != nullptr) { + con_st.out = stdout; + fclose(con_st.tty); + con_st.tty = nullptr; + } + // Restore the terminal settings on POSIX systems + tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); +#endif +} + +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void console_set_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + fflush(stdout); + switch(color) { + case CONSOLE_COLOR_DEFAULT: + fprintf(con_st.out, ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + fprintf(con_st.out, ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + fflush(con_st.out); } } -// Convert a wide Unicode string to an UTF8 string -void win32_utf8_encode(const std::wstring & wstr, std::string & str) { - int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); - std::string strTo(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); - str = strTo; -} +char32_t getchar32() { + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } #endif + + return static_cast(wc); +} + +void pop_cursor(console_state & con_st) { +#if defined(_WIN32) + if (con_st.hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); + return; + } +#endif + putc('\b', con_st.out); +} + +int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif +} + +int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // we can trust expectedWidth if we've got one + if (expectedWidth >= 0 || con_st.tty == nullptr) { + fwrite(utf8_codepoint, length, 1, con_st.out); + return expectedWidth; + } + + fputs("\033[6n", con_st.tty); // Query cursor position + int x1, x2, y1, y2; + int results = 0; + results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, con_st.tty); + + fputs("\033[6n", con_st.tty); // Query cursor position + results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif +} + +void replace_last(console_state & con_st, char ch) { +#if defined(_WIN32) + pop_cursor(con_st); + put_codepoint(con_st, &ch, 1, 1); +#else + fprintf(con_st.out, "\b%c", ch); +#endif +} + +void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } +} + +// Helper function to remove the last UTF-8 character from a string +void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character + } + line.erase(pos); +} + +bool console_readline(console_state & con_st, std::string & line) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + if (con_st.out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(con_st.out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + replace_last(con_st, line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(con_st, ' '); + pop_cursor(con_st); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + console_set_color(con_st, CONSOLE_COLOR_PROMPT); + replace_last(con_st, line.back()); + is_special_char = true; + } + } + + bool has_more = con_st.multiline_input; + if (is_special_char) { + replace_last(con_st, ' '); + pop_cursor(con_st); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', con_st.out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(con_st); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', con_st.out); + } + } + + fflush(con_st.out); + return has_more; +} diff --git a/examples/common.h b/examples/common.h index 842e1516f..43f1cc9ef 100644 --- a/examples/common.h +++ b/examples/common.h @@ -10,6 +10,11 @@ #include #include +#if !defined (_WIN32) +#include +#include +#endif + // // CLI argument parsing // @@ -56,6 +61,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately + bool multiline_input = false; // reverse the usage of `\` bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token @@ -104,13 +110,20 @@ enum console_color_t { }; struct console_state { + bool multiline_input = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; + + FILE* out = stdout; +#if defined (_WIN32) + void* hConsole; +#else + FILE* tty = nullptr; + termios prev_state; +#endif }; -void set_console_color(console_state & con_st, console_color_t color); - -#if defined (_WIN32) -void win32_console_init(bool enable_color); -void win32_utf8_encode(const std::wstring & wstr, std::string & str); -#endif +void console_init(console_state & con_st); +void console_cleanup(console_state & con_st); +void console_set_color(console_state & con_st, console_color_t color); +bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 045093c72..6e1172a48 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -35,12 +35,12 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; } else { + console_cleanup(con_st); + printf("\n"); llama_print_timings(*g_ctx); _exit(130); } @@ -59,10 +59,9 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) con_st.use_color = params.use_color; - -#if defined (_WIN32) - win32_console_init(params.use_color); -#endif + con_st.multiline_input = params.multiline_input; + console_init(con_st); + atexit([]() { console_cleanup(con_st); }); if (params.perplexity) { printf("\n************\n"); @@ -275,12 +274,21 @@ int main(int argc, char ** argv) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { + const char *control_message; + if (con_st.multiline_input) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } fprintf(stderr, "== Running in interactive mode. ==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); + "%s\n", control_message); + is_interacting = params.interactive_first; } @@ -299,7 +307,7 @@ int main(int argc, char ** argv) { int n_session_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_color(con_st, CONSOLE_COLOR_PROMPT); + console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; @@ -498,7 +506,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -518,17 +526,12 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; is_antiprompt = true; - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - fflush(stdout); break; } } } if (n_past > 0 && is_interacting) { - // potentially set color to indicate we are taking user input - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (params.instruct) { printf("\n> "); } @@ -542,31 +545,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { -#if defined(_WIN32) - std::wstring wline; - if (!std::getline(std::wcin, wline)) { - // input stream is bad or EOF received - return 0; - } - win32_utf8_encode(wline, line); -#else - if (!std::getline(std::cin, line)) { - // input stream is bad or EOF received - return 0; - } -#endif - if (!line.empty()) { - if (line.back() == '\\') { - line.pop_back(); // Remove the continue character - } else { - another_line = false; - } - buffer += line + '\n'; // Append the line to the result - } + another_line = console_readline(con_st, line); + buffer += line; } while (another_line); // done taking input, reset color - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back @@ -622,7 +606,5 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - return 0; }