From 799fdc1b5d888b8a8682baf112e1c2a2df0df1c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 3 May 2023 23:24:20 +0300 Subject: [PATCH 01/27] ggml : vectorize Q8_0 quantization https://github.com/ggerganov/ggml/pull/127#issuecomment-1533648531 --- ggml.c | 120 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/ggml.c b/ggml.c index addf0c308..0bcb5f617 100644 --- a/ggml.c +++ b/ggml.c @@ -1509,15 +1509,135 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r } static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { + assert(QK8_0 == 32); assert(k % QK8_0 == 0); + const int nb = k / QK8_0; block_q8_0 * restrict y = vy; +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + // scalar quantize_row_q8_0_reference(x, y, k); +#endif } // reference implementation for deterministic creation of model files static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { + assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; From f647ce040ff06348d2ceaa5443a6a7a8b80c70c9 Mon Sep 17 00:00:00 2001 From: Tomas Date: Thu, 4 May 2023 17:02:30 +0700 Subject: [PATCH 02/27] fix #1224 reverse prompt and multi line (#1297) * fix reverse prompt and multi line * Code Formatting Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- examples/main/main.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 125c189a3..17a5a90d1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -551,12 +551,14 @@ int main(int argc, char ** argv) { return 0; } #endif - if (line.empty() || line.back() != '\\') { - another_line = false; - } else { - line.pop_back(); // Remove the continue character + if (!line.empty()) { + if (line.back() == '\\') { + line.pop_back(); // Remove the continue character + } else { + another_line = false; + } + buffer += line + '\n'; // Append the line to the result } - buffer += line + '\n'; // Append the line to the result } while (another_line); // done taking input, reset color From c65a7fbfa9c736416a25369cc05d356789df4c15 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Thu, 4 May 2023 03:02:59 -0700 Subject: [PATCH 03/27] Update main's README.md with new features (#1296) --- examples/main/README.md | 141 ++++++++++++++++++++++++++++++++-------- 1 file changed, 113 insertions(+), 28 deletions(-) diff --git a/examples/main/README.md b/examples/main/README.md index ba210d14a..493a8c095 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -17,23 +17,45 @@ This example program allows you to use various LLaMA language models in an easy To get started right away, run the following command, making sure to use the correct path for the model you have: +#### Unix-based systems (Linux, macOS, etc.): + ```bash ./main -m models/7B/ggml-model.bin --prompt "Once upon a time" ``` -The following command generates "infinite" text from a starting prompt (you can use `Ctrl-C` to stop it): +#### Windows: -```bash -./main -m models/7B/ggml-model.bin --ignore-eos --n_predict -1 --keep -1 --prompt "Once upon a time" +```powershell +main.exe -m models\7B\ggml-model.bin --prompt "Once upon a time" ``` For an interactive experience, try this command: +#### Unix-based systems (Linux, macOS, etc.): + ```bash -./main -m models/7B/ggml-model.bin -n -1 --color -r "User:" --in-prefix " " --prompt $'User: Hi\nAI: Hello. I am an AI chatbot. Would you like to talk?\nUser: Sure!\nAI: What would you like to talk about?\nUser:' +./main -m models/7B/ggml-model.bin -n -1 --color -r "User:" --in-prefix " " --prompt 'User: Hi\nAI: Hello. I am an AI chatbot. Would you like to talk?\nUser: Sure!\nAI: What would you like to talk about?\nUser:' ``` -Note that the newline characters in the prompt string above only work on Linux. On Windows, you will have to use the ``--file`` option (see below) to load a multi-line prompt from file instead. +#### Windows: + +```powershell +main.exe -m models\7B\ggml-model.bin -n -1 --color -r "User:" --in-prefix " " --prompt "User: Hi\nAI: Hello. I am an AI chatbot. Would you like to talk?\nUser: Sure!\nAI: What would you like to talk about?\nUser:" +``` + +The following command generates "infinite" text from a starting prompt (you can use `Ctrl-C` to stop it): + +#### Unix-based systems (Linux, macOS, etc.): + +```bash +./main -m models/7B/ggml-model.bin --ignore-eos -n -1 --random-prompt +``` + +#### Windows: + +```powershell +main.exe -m models\7B\ggml-model.bin --ignore-eos -n -1 --random-prompt +``` ## Common Options @@ -42,7 +64,6 @@ In this section, we cover the most commonly used options for running the `main` - `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). - `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. - `-ins, --instruct`: Run the program in instruction mode, which is particularly useful when working with Alpaca models. -- `-t N, --threads N`: Set the number of threads to use during computation. It is recommended to set this to the number of physical cores your CPU has. - `-n N, --n_predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text. - `-c N, --ctx_size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. @@ -92,7 +113,7 @@ Instruction mode is particularly useful when working with Alpaca models, which a - `-ins, --instruct`: Enable instruction mode to leverage the capabilities of Alpaca models in completing tasks based on user-provided instructions. -Technical detail: the user's input is internally prefixed with the reverse prompt (or ``### Instruction:`` as the default), and followed by ``### Response:`` (except if you just press Return without any input, to keep generating a longer response). +Technical detail: the user's input is internally prefixed with the reverse prompt (or `### Instruction:` as the default), and followed by `### Response:` (except if you just press Return without any input, to keep generating a longer response). By understanding and utilizing these interaction options, you can create engaging and dynamic experiences with the LLaMA models, tailoring the text generation process to your specific needs. @@ -116,7 +137,7 @@ By utilizing context management options like `--ctx_size` and `--keep`, you can ## Generation Flags -The following options are related to controlling the text generation process, influencing the diversity, creativity, and quality of the generated text. Understanding these options will help you fine-tune the output according to your needs: +The following options allow you to control the text generation process and fine-tune the diversity, creativity, and quality of the generated text according to your needs. By adjusting these options and experimenting with different combinations of values, you can find the best settings for your specific use case. ### Number of Tokens to Predict @@ -124,13 +145,7 @@ The following options are related to controlling the text generation process, in The `--n_predict` option controls the number of tokens the model generates in response to the input prompt. By adjusting this value, you can influence the length of the generated text. A higher value will result in longer text, while a lower value will produce shorter text. A value of -1 will cause text to be generated without limit. -It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `n_predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the ``--ignore-eos`` parameter. - -### RNG Seed - -- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1). - -The RNG seed is used to initialize the random number generator that influences the text generation process. By setting a specific seed value, you can obtain consistent and reproducible results across multiple runs with the same input and settings. This can be helpful for testing, debugging, or comparing the effects of different options on the generated text to see when they diverge. If the seed is set to a value less than 0, a random seed will be used, which will result in different outputs on each run. +It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `n_predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter. ### Temperature @@ -138,15 +153,21 @@ The RNG seed is used to initialize the random number generator that influences t Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run. -Example usage: `--temp 0.8` +Example usage: `--temp 0.5` ### Repeat Penalty - `--repeat_penalty N`: Control the repetition of token sequences in the generated text (default: 1.1). +- `--repeat_last_n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx_size). +- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty. -Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.1. +The `repeat_penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.1. -Example usage: `--repeat_penalty 1.1` +The `repeat_last_n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx_size`). + +Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases. + +Example usage: `--repeat_penalty 1.15 --repeat_last_n 128 --no-penalize-nl` ### Top-K Sampling @@ -154,7 +175,7 @@ Example usage: `--repeat_penalty 1.1` Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text. The default value is 40. -Example usage: `--top_k 40` +Example usage: `--top_k 30` ### Top-P Sampling @@ -162,23 +183,87 @@ Example usage: `--top_k 40` Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. The default value is 0.9. -Example usage: `--top_p 0.9` +Example usage: `--top_p 0.95` -By adjusting these options, you can control the diversity, quality, and creativity of the generated text to better suit your needs. You can experiment with different combinations of values to find the best settings for your specific use case. +### Tail Free Sampling (TFS) + +- `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled). + +Tail free sampling (TFS) is a text generation technique that aims to reduce the impact of less likely tokens, which may be less relevant, less coherent, or nonsensical, on the output. The method adjusts the logits (token probabilities) by raising them to the power of the parameter z. A higher value of z (e.g., 2.0) will further suppress less likely tokens from the tail of the distribution, while a value of 1.0 disables the effect of TFS. By setting the parameter z, you can control how much the probabilities of less likely tokens are reduced. + +Example usage: `--tfs 2.0` + +### Locally Typical Sampling + +- `--typical N`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled). + +Locally typical sampling promotes the generation of contextually coherent and diverse text by sampling tokens that are typical or expected based on the surrounding context. By setting the parameter p between 0 and 1, you can control the balance between producing text that is locally coherent and diverse. A value closer to 1 will promote more contextually coherent tokens, while a value closer to 0 will promote more diverse tokens. A value equal to 1 disables locally typical sampling. + +Example usage: `--typical 0.9` + +### Mirostat Sampling + +- `--mirostat N`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0). +- `--mirostat_lr N`: Set the Mirostat learning rate, parameter eta (default: 0.1). +- `--mirostat_ent N`: Set the Mirostat target entropy, parameter tau (default: 5.0). + +Mirostat is an algorithm that actively maintains the quality of generated text within a desired range during text generation. It aims to strike a balance between coherence and diversity, avoiding low-quality output caused by excessive repetition (boredom traps) or incoherence (confusion traps). + +The `--mirostat_lr` option sets the Mirostat learning rate (eta). The learning rate influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. The default value is `0.1`. + +The `--mirostat_ent` option sets the Mirostat target entropy (tau), which represents the desired perplexity value for the generated text. Adjusting the target entropy allows you to control the balance between coherence and diversity in the generated text. A lower value will result in more focused and coherent text, while a higher value will lead to more diverse and potentially less coherent text. The default value is `5.0`. + +Example usage: `--mirostat 2 --mirostat_lr 0.05 --mirostat_ent 3.0` + +### Logit Bias + +- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion. + +The logit bias option allows you to manually adjust the likelihood of specific tokens appearing in the generated text. By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated. + +For example, use `--logit-bias 15043+1` to increase the likelihood of the token 'Hello', or `--logit-bias 15043-1` to decrease its likelihood. Using a value of negative infinity, `--logit-bias 15043-inf` ensures that the token `Hello` is never produced. + +A more practical use case might be to prevent the generation of `\code{begin}` and `\code{end}` by setting the `\` token (29905) to negative infinity with `-l 29905-inf`. (This is due to the prevalence of LaTeX codes that show up in LLaMA model inference.) + +Example usage: `--logit-bias 29905-inf` + +### RNG Seed + +- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, < 0 = random seed). + +The RNG seed is used to initialize the random number generator that influences the text generation process. By setting a specific seed value, you can obtain consistent and reproducible results across multiple runs with the same input and settings. This can be helpful for testing, debugging, or comparing the effects of different options on the generated text to see when they diverge. If the seed is set to a value less than 0, a random seed will be used, which will result in different outputs on each run. ## Performance Tuning and Memory Options -These options help improve the performance and memory usage of the LLaMA models: +These options help improve the performance and memory usage of the LLaMA models. By adjusting these settings, you can fine-tune the model's behavior to better suit your system's capabilities and achieve optimal performance for your specific use case. + +### Number of Threads + +- `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance. + +### Mlock + +- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped. This can improve performance but trades away some of the advantages of memory-mapping by requiring more RAM to run and potentially slowing down load times as the model loads into RAM. + +### No Memory Mapping + +- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low on available memory, using mmap might increase the risk of pageouts, negatively impacting performance. Disabling mmap results in slower load times but may reduce pageouts if you're not using `--mlock`. Note that if the model is larger than the total amount of RAM, turning off mmap would prevent the model from loading at all. + +### Memory Float 32 + +- `--memory_f32`: Use 32-bit floats instead of 16-bit floats for memory key+value, allowing higher quality inference at the cost of higher memory usage. + +### Batch Size -- `-t N, --threads N`: Set the number of threads to use during computation. Using the correct number of threads can greatly improve performance. It is recommended to set this value to the number of CPU cores. -- `--mlock`: Lock the model in memory, preventing it from being swapped out when mmaped. This can improve performance. -- `--no-mmap`: Do not memory-map the model. This results in a slower load time but may reduce pageouts if you're not using `mlock`. -- `--memory_f32`: Use 32 bit floats instead of 16 bit floats for memory key+value, allowing higher quality inference at the cost of memory. - `-b N, --batch_size N`: Set the batch size for prompt processing (default: 512). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations. -For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-data--run). +### Session Caching -By understanding and using these performance tuning settings, you can optimize the LLaMA model's behavior to achieve the best performance for your specific needs. +- `--session FNAME`: Specify a file to load/save the session, which caches the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The session file is created during the first run and is reused in subsequent runs. If you change your prompt such that 75% or less of the session is reusable, the existing session file will be overwritten with a new, updated version to maintain optimal performance. + +### Quantization + +For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-data--run). ## Additional Options From db1080876a62ec3bb4119d90b16e7dce7594b733 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Thu, 4 May 2023 05:08:25 -0700 Subject: [PATCH 04/27] Only escape prompts when used with `-e` (#1311) --- examples/common.cpp | 46 ++++++++++++++++++++++------------------- examples/main/README.md | 9 ++++++-- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 1a2f4743a..cd6300041 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -66,35 +66,33 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -std::string process_escapes(const char* input) { - std::string output; +void process_escapes(std::string& input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; - if (input != nullptr) { - std::size_t input_len = std::strlen(input); - output.reserve(input_len); - - for (std::size_t i = 0; i < input_len; ++i) { - if (input[i] == '\\' && i + 1 < input_len) { - switch (input[++i]) { - case 'n': output.push_back('\n'); break; - case 't': output.push_back('\t'); break; - case '\'': output.push_back('\''); break; - case '\"': output.push_back('\"'); break; - case '\\': output.push_back('\\'); break; - default: output.push_back('\\'); - output.push_back(input[i]); break; - } - } else { - output.push_back(input[i]); + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + default: input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; } + } else { + input[output_idx++] = input[input_idx]; } } - return output; + input.resize(output_idx); } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; + bool escape_prompt = false; std::string arg; gpt_params default_params; @@ -118,7 +116,9 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.prompt = process_escapes(argv[i]); + params.prompt = argv[i]; + } else if (arg == "-e") { + escape_prompt = true; } else if (arg == "--session") { if (++i >= argc) { invalid_param = true; @@ -335,6 +335,9 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { gpt_print_usage(argc, argv, default_params); exit(1); } + if (escape_prompt) { + process_escapes(params.prompt); + } return true; } @@ -355,6 +358,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); fprintf(stderr, " prompt to start generation with (default: empty)\n"); + fprintf(stderr, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n"); fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); diff --git a/examples/main/README.md b/examples/main/README.md index 493a8c095..6b7facb3b 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -34,13 +34,18 @@ For an interactive experience, try this command: #### Unix-based systems (Linux, macOS, etc.): ```bash -./main -m models/7B/ggml-model.bin -n -1 --color -r "User:" --in-prefix " " --prompt 'User: Hi\nAI: Hello. I am an AI chatbot. Would you like to talk?\nUser: Sure!\nAI: What would you like to talk about?\nUser:' +./main -m models/7B/ggml-model.bin -n -1 --color -r "User:" --in-prefix " " \ +'User: Hi +AI: Hello. I am an AI chatbot. Would you like to talk? +User: Sure! +AI: What would you like to talk about? +User:' ``` #### Windows: ```powershell -main.exe -m models\7B\ggml-model.bin -n -1 --color -r "User:" --in-prefix " " --prompt "User: Hi\nAI: Hello. I am an AI chatbot. Would you like to talk?\nUser: Sure!\nAI: What would you like to talk about?\nUser:" +main.exe -m models\7B\ggml-model.bin -n -1 --color -r "User:" --in-prefix " " -e --prompt "User: Hi\nAI: Hello. I am an AI chatbot. Would you like to talk?\nUser: Sure!\nAI: What would you like to talk about?\nUser:" ``` The following command generates "infinite" text from a starting prompt (you can use `Ctrl-C` to stop it): From 20fbf2a2a08d8edefe9b3435fa86f8b2f63f8588 Mon Sep 17 00:00:00 2001 From: Ron Jailall Date: Thu, 4 May 2023 11:05:59 -0400 Subject: [PATCH 05/27] ggml : change immintrin.h to intrin.h for compatibility (#1307) * change immintrin.h to intrin.h for compatibility Building on windows11 arm throws an error on this line. Seems like using intrin.h covers x86 and and arm * conditional def of intrin.h * fix typo in ggml.c --- ggml.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml.c b/ggml.c index 0bcb5f617..4d49242a4 100644 --- a/ggml.c +++ b/ggml.c @@ -180,9 +180,13 @@ typedef double ggml_float; #undef bool #define bool _Bool #else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else #include #endif #endif +#endif #ifdef __F16C__ From 2edbdb0f99336cb41f0995061c7602ed54beb863 Mon Sep 17 00:00:00 2001 From: 44670 <44670@users.noreply.github.com> Date: Thu, 4 May 2023 23:41:12 +0800 Subject: [PATCH 06/27] main : add --in-suffix option (#1318) * adding --in-suffix option * print input suffix before generation --- examples/common.cpp | 7 +++++++ examples/common.h | 1 + examples/main/README.md | 8 ++++++++ examples/main/main.cpp | 9 +++++++++ 4 files changed, 25 insertions(+) diff --git a/examples/common.cpp b/examples/common.cpp index cd6300041..97eded6ec 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -324,6 +324,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.input_prefix = argv[i]; + } else if (arg == "--in-suffix") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.input_suffix = argv[i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); @@ -362,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n"); fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); + fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); diff --git a/examples/common.h b/examples/common.h index 138d0ded0..842e1516f 100644 --- a/examples/common.h +++ b/examples/common.h @@ -43,6 +43,7 @@ struct gpt_params { std::string prompt = ""; std::string path_session = ""; // path to file for saving/loading model eval state std::string input_prefix = ""; // string to prefix user inputs with + std::string input_suffix = ""; // string to suffix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path diff --git a/examples/main/README.md b/examples/main/README.md index 6b7facb3b..35f87bcd5 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -112,6 +112,14 @@ The `--in-prefix` flag is used to add a prefix to your input, primarily, this is ./main -r "User:" --in-prefix " " ``` +### In-Suffix + +The `--in-suffix` flag is used to add a suffix after your input. This is useful for adding an "Assistant:" prompt after the user's input. It's added after the new-line character (`\n`) that's automatically added to the end of the user's input. Here's an example of how to use the `--in-suffix` flag in conjunction with the `--reverse-prompt` flag: + +```sh +./main -r "User:" --in-prefix " " --in-suffix "Assistant:" +``` + ### Instruction Mode Instruction mode is particularly useful when working with Alpaca models, which are designed to follow user instructions for specific tasks: diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 17a5a90d1..43dca8eb5 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -260,6 +260,10 @@ int main(int argc, char ** argv) { if (!params.input_prefix.empty()) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } + + if (!params.input_suffix.empty()) { + fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str()); + } } fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); @@ -567,6 +571,11 @@ int main(int argc, char ** argv) { // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back if (buffer.length() > 1) { + // append input suffix if any + if (!params.input_suffix.empty()) { + buffer += params.input_suffix; + printf("%s", params.input_suffix.c_str()); + } // instruct mode: insert instruction prefix if (params.instruct && !is_antiprompt) { From 360cfe5bec852805b84eec799102fc6f45df9fef Mon Sep 17 00:00:00 2001 From: 44670 <44670@users.noreply.github.com> Date: Fri, 5 May 2023 00:33:31 +0800 Subject: [PATCH 07/27] readme : add OpenBuddy link (#1321) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0002f8cc1..f1fa63542 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ as the main playground for developing new features for the [ggml](https://github - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne) - [X] [Vicuna](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5533894) - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) +- [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy) **Bindings:** From d3e8093e9b5845514b049ede3b12728c8f013eba Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Thu, 4 May 2023 19:54:37 +0300 Subject: [PATCH 08/27] convert: support DT_BF16 tensors (#1309) Co-authored-by: Pavol Rusnak --- convert.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/convert.py b/convert.py index 7f7ae05fa..c817a343e 100644 --- a/convert.py +++ b/convert.py @@ -67,6 +67,7 @@ FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \ {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()} DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = { + DT_BF16: np.dtype(np.uint16), DT_F16: np.dtype(np.float16), DT_F32: np.dtype(np.float32), DT_I32: np.dtype(np.int32), @@ -276,6 +277,12 @@ class Tensor(metaclass=ABCMeta): def to_ggml(self) -> 'GGMLCompatibleTensor': ... +def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray: + assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + fp32_arr = bf16_arr.astype(np.uint32) << 16 + return fp32_arr.view(np.float32) + + class UnquantizedTensor(Tensor): def __init__(self, ndarray: NDArray) -> None: assert isinstance(ndarray, np.ndarray) @@ -284,6 +291,8 @@ class UnquantizedTensor(Tensor): def astype(self, data_type: DataType) -> Tensor: dtype = DATA_TYPE_TO_NUMPY[data_type] + if self.data_type == DT_BF16: + self.ndarray = bf16_to_fp32(self.ndarray) return UnquantizedTensor(self.ndarray.astype(dtype)) def to_ggml(self) -> 'UnquantizedTensor': @@ -686,6 +695,7 @@ class LazyUnpickler(pickle.Unpickler): description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' return LazyStorage(load=load, kind=pid[1], description=description) + @staticmethod def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, # pyright: ignore[reportSelfClsParameterName] requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: assert isinstance(storage, LazyStorage) @@ -696,12 +706,18 @@ class LazyUnpickler(pickle.Unpickler): description = f'pickled storage_offset={storage_offset} in {storage.description}' return LazyTensor(load, list(size), storage.kind.data_type, description) + @staticmethod + def rebuild_from_type_v2(func, new_type, args, state): + return func(*args) + CLASSES: Dict[Any, Any] = { + ('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2, ('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2, ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), ('torch', 'IntStorage'): LazyStorageKind(DT_I32), + ('torch', 'Tensor'): LazyTensor, } def find_class(self, module: str, name: str) -> Any: @@ -961,7 +977,7 @@ class OutputFile: def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType: wq_type = model["layers.0.attention.wq.weight"].data_type - if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): + if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): return GGMLFileType.AllF32 if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): return GGMLFileType.MostlyF16 From 34d9f22f44c42d345cc72c8f3aa4cb71c5df0acb Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Thu, 4 May 2023 19:56:27 +0300 Subject: [PATCH 09/27] Wrap exceptions in std::exception to verbose output on exception. (#1316) --- llama-util.h | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/llama-util.h b/llama-util.h index d531588d5..88ec28dca 100644 --- a/llama-util.h +++ b/llama-util.h @@ -14,6 +14,7 @@ #include #include +#include #ifdef __has_include #if __has_include() @@ -74,7 +75,7 @@ struct llama_file { llama_file(const char * fname, const char * mode) { fp = std::fopen(fname, mode); if (fp == NULL) { - throw format("failed to open %s: %s", fname, std::strerror(errno)); + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); } seek(0, SEEK_END); size = tell(); @@ -107,10 +108,10 @@ struct llama_file { errno = 0; std::size_t ret = std::fread(ptr, size, 1, fp); if (ferror(fp)) { - throw format("read error: %s", strerror(errno)); + throw std::runtime_error(format("read error: %s", strerror(errno))); } if (ret != 1) { - throw std::string("unexpectedly reached end of file"); + throw std::runtime_error(std::string("unexpectedly reached end of file")); } } @@ -133,7 +134,7 @@ struct llama_file { errno = 0; size_t ret = std::fwrite(ptr, size, 1, fp); if (ret != 1) { - throw format("write error: %s", strerror(errno)); + throw std::runtime_error(format("write error: %s", strerror(errno))); } } @@ -180,7 +181,7 @@ struct llama_mmap { #endif addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); if (addr == MAP_FAILED) { - throw format("mmap failed: %s", strerror(errno)); + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); } if (prefetch) { @@ -207,7 +208,7 @@ struct llama_mmap { DWORD error = GetLastError(); if (hMapping == NULL) { - throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()); + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); } addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); @@ -215,7 +216,7 @@ struct llama_mmap { CloseHandle(hMapping); if (addr == NULL) { - throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()); + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); } #if _WIN32_WINNT >= _WIN32_WINNT_WIN8 @@ -245,7 +246,7 @@ struct llama_mmap { llama_mmap(struct llama_file *, bool prefetch = true) { (void)prefetch; - throw std::string("mmap not supported"); + throw std::runtime_error(std::string("mmap not supported")); } #endif }; From 94c5652fc0f4d04ac54412c4d81e2ebcdafb6ede Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 5 May 2023 00:58:56 +0200 Subject: [PATCH 10/27] quantize: make output filename optional, default to ggml-model-.bin (#1301) --- examples/quantize/quantize.cpp | 100 ++++++++++++++++++++++++++------- 1 file changed, 81 insertions(+), 19 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 198bd5fcb..7c77018da 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -6,23 +6,47 @@ #include #include -static const std::map LLAMA_FTYPE_MAP = { - {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0}, - {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, - {"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2}, - {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0}, - {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1}, - {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, +static const std::map LLAMA_FTYPE_MAP = { + {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0}, + {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, + {"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2}, + {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0}, + {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1}, + {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, }; +bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) { + auto it = LLAMA_FTYPE_MAP.find(ftype_str); + if (it != LLAMA_FTYPE_MAP.end()) { + ftype = it->second; + ftype_str_out = it->first; + return true; + } + // try to parse as an integer + try { + int ftype_int = std::stoi(ftype_str); + for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { + if (it->second == ftype_int) { + ftype = it->second; + ftype_str_out = it->first; + return true; + } + } + } + catch (...) { + // stoi failed + } + return false; +} + // usage: -// ./quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type +// ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads] // int main(int argc, char ** argv) { ggml_time_init(); - if (argc < 4) { - fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type [nthread]\n", argv[0]); + if (argc < 3) { + fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]); for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second); } @@ -36,24 +60,62 @@ int main(int argc, char ** argv) { ggml_free(ctx); } + // parse command line arguments const std::string fname_inp = argv[1]; - const std::string fname_out = argv[2]; + std::string fname_out; + int nthread; + llama_ftype ftype; - enum llama_ftype ftype; - if (argv[3][0] == 'q') { - auto it = LLAMA_FTYPE_MAP.find(argv[3]); - if (it == LLAMA_FTYPE_MAP.end()) { - fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, argv[3]); + int arg_idx = 2; + std::string ftype_str; + if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) { + // argv[2] is the ftype + std::string fpath; + const size_t pos = fname_inp.find_last_of('/'); + if (pos != std::string::npos) { + fpath = fname_inp.substr(0, pos + 1); + } + // export as [inp path]/ggml-model-[ftype].bin + fname_out = fpath + "ggml-model-" + ftype_str + ".bin"; + arg_idx++; + } + else { + // argv[2] is the output path + fname_out = argv[arg_idx]; + arg_idx++; + + if (argc <= arg_idx) { + fprintf(stderr, "%s: missing ftype\n", __func__); + return 1; + } + // argv[3] is the ftype + if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) { + fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]); + return 1; + } + arg_idx++; + } + + // parse nthreads + if (argc > arg_idx) { + try { + nthread = std::stoi(argv[arg_idx]); + } + catch (const std::exception & e) { + fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what()); return 1; } - ftype = it->second; } else { - ftype = (enum llama_ftype)atoi(argv[3]); + nthread = 0; } fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); - int nthread = argc > 4 ? atoi(argv[4]) : 0; + fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str()); + if (nthread > 0) { + fprintf(stderr, " using %d threads", nthread); + } + fprintf(stderr, "\n"); const int64_t t_main_start_us = ggml_time_us(); From a90e96b266873ebb5e947c9864b12193bdada0fb Mon Sep 17 00:00:00 2001 From: Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com> Date: Fri, 5 May 2023 02:17:07 +0200 Subject: [PATCH 11/27] Convert.py @staticmethod (#1327) * Line 698 has one #staticmethod and should not otherwise throw error at unpickle.load() as not callable * Update convert.py --------- Co-authored-by: Ivan Stepanov --- convert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert.py b/convert.py index c817a343e..126beaabc 100644 --- a/convert.py +++ b/convert.py @@ -695,7 +695,7 @@ class LazyUnpickler(pickle.Unpickler): description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' return LazyStorage(load=load, kind=pid[1], description=description) - @staticmethod + # @staticmethod def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, # pyright: ignore[reportSelfClsParameterName] requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: assert isinstance(storage, LazyStorage) @@ -706,7 +706,7 @@ class LazyUnpickler(pickle.Unpickler): description = f'pickled storage_offset={storage_offset} in {storage.description}' return LazyTensor(load, list(size), storage.kind.data_type, description) - @staticmethod + # @staticmethod def rebuild_from_type_v2(func, new_type, args, state): return func(*args) From 2d13786e91ec9fd28ddf737053822042a824da78 Mon Sep 17 00:00:00 2001 From: Ionoclast Laboratories Date: Fri, 5 May 2023 08:18:21 -0400 Subject: [PATCH 12/27] Fix for OpenCL / clbast builds on macOS. (#1329) --- Makefile | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 94acefdde..260b2487f 100644 --- a/Makefile +++ b/Makefile @@ -121,7 +121,12 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h endif ifdef LLAMA_CLBLAST CFLAGS += -DGGML_USE_CLBLAST - LDFLAGS += -lclblast -lOpenCL + # Mac provides OpenCL as a framework + ifeq ($(UNAME_S),Darwin) + LDFLAGS += -lclblast -framework OpenCL + else + LDFLAGS += -lclblast -lOpenCL + endif OBJS += ggml-opencl.o ggml-opencl.o: ggml-opencl.c ggml-opencl.h $(CC) $(CFLAGS) -c $< -o $@ From 921dcee00a55d9aba3b3026d0509d31ac8386e2a Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Fri, 5 May 2023 16:43:36 +0200 Subject: [PATCH 13/27] readme: add missing info (#1324) --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f1fa63542..233c5c5e1 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,12 @@ The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quant - Plain C/C++ implementation without dependencies - Apple silicon first-class citizen - optimized via ARM NEON and Accelerate framework -- AVX2 support for x86 architectures +- AVX, AVX2 and AVX512 support for x86 architectures - Mixed F16 / F32 precision -- 4-bit integer quantization support +- 4-bit, 5-bit and 8-bit integer quantization support - Runs on the CPU +- OpenBLAS support +- cuBLAS and CLBlast support The original implementation of `llama.cpp` was [hacked in an evening](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022). Since then, the project has improved significantly thanks to many contributions. This project is for educational purposes and serves From a3b85b28da84c67c3406807aef5e0457bcc4b00f Mon Sep 17 00:00:00 2001 From: Erik Scholz Date: Fri, 5 May 2023 22:56:09 +0200 Subject: [PATCH 14/27] ci : add cublas to windows release (#1271) --- .github/workflows/build.yml | 77 +++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 179080576..18bb33f94 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -210,6 +210,82 @@ jobs: path: | llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip + windows-latest-cmake-cublas: + runs-on: windows-latest + + strategy: + matrix: + cuda: ['12.1.0', '11.7.1'] + build: ['cublas'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v1 + + - uses: Jimver/cuda-toolkit@v0.2.10 + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda }} + # TODO(green-sky): _dev seems to fail, and non dev are not enought + #sub-packages: '["nvcc", "cudart", "cublas", "cudart_dev", "cublas_dev"]' + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake .. -DLLAMA_CUBLAS=ON + cmake --build . --config Release + + - name: Get commit hash + id: commit + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + uses: pr-mpt/actions-commit-hash@v2 + + - name: Pack artifacts + id: pack_artifacts + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + run: | + 7z a llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-cu${{ matrix.cuda }}-x64.zip .\build\bin\Release\* + + - name: Upload artifacts + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + uses: actions/upload-artifact@v3 + with: + path: | + llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-cu${{ matrix.cuda }}-x64.zip + + - name: Copy and pack Cuda runtime + if: ${{ matrix.cuda == '12.1.0' }} + # TODO(green-sky): paths are cuda 12 specific + run: | + echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" + mkdir '.\build\bin\cudart\' + cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cudart64_12.dll" '.\build\bin\cudart\' + cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublas64_12.dll" '.\build\bin\cudart\' + cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublasLt64_12.dll" '.\build\bin\cudart\' + 7z a cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip .\build\bin\cudart\* + + - name: Copy and pack Cuda runtime + if: ${{ matrix.cuda == '11.7.1' }} + # TODO(green-sky): paths are cuda 11 specific + run: | + echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" + mkdir '.\build\bin\cudart\' + ls "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin" + cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cudart64_110.dll" '.\build\bin\cudart\' + cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublas64_11.dll" '.\build\bin\cudart\' + cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublasLt64_11.dll" '.\build\bin\cudart\' + 7z a cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip .\build\bin\cudart\* + + - name: Upload Cuda runtime + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + uses: actions/upload-artifact@v3 + with: + path: | + cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip + release: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} @@ -221,6 +297,7 @@ jobs: - macOS-latest-make - macOS-latest-cmake - windows-latest-cmake + - windows-latest-cmake-cublas steps: - name: Download artifacts From 173d0e6419e8f8f3c1f4f13201b777f4c60629f3 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 5 May 2023 23:57:14 +0200 Subject: [PATCH 15/27] makefile: automatic Arch Linux detection (#1332) This commit is a port of a detection method used in koboldcpp's Makefile in order to automatically set the -lcblas option on Arch Linux --- Makefile | 6 +++++- README.md | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 260b2487f..0ddff9961 100644 --- a/Makefile +++ b/Makefile @@ -107,7 +107,11 @@ ifndef LLAMA_NO_ACCELERATE endif ifdef LLAMA_OPENBLAS CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas - LDFLAGS += -lopenblas + ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),) + LDFLAGS += -lopenblas -lcblas + else + LDFLAGS += -lopenblas + endif endif ifdef LLAMA_CUBLAS CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include diff --git a/README.md b/README.md index 233c5c5e1..19cc94aa2 100644 --- a/README.md +++ b/README.md @@ -216,7 +216,6 @@ Building the program with BLAS support may lead to some performance improvements ```bash make LLAMA_OPENBLAS=1 ``` - Note: In order to build on Arch Linux with OpenBLAS support enabled you must edit the Makefile adding at the end of the line 105: `-lcblas` - On Windows: From 3924088512d9e12e90ed6dbf28a6c5712481d33e Mon Sep 17 00:00:00 2001 From: Jed Fox Date: Sat, 6 May 2023 17:01:47 -0400 Subject: [PATCH 16/27] Remove default arguments from sampling functions (#1343) --- .gitignore | 1 + examples/main/main.cpp | 8 ++++---- llama.cpp | 2 +- llama.h | 8 ++++---- tests/test-sampling.cpp | 8 ++++---- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index e479c6180..6f275fea4 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ build-sanitize-addr/ build-sanitize-thread/ models/* +*.bin /main /quantize diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 43dca8eb5..5ac151e14 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -444,10 +444,10 @@ int main(int argc, char ** argv) { id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling - llama_sample_top_k(ctx, &candidates_p, top_k); - llama_sample_tail_free(ctx, &candidates_p, tfs_z); - llama_sample_typical(ctx, &candidates_p, typical_p); - llama_sample_top_p(ctx, &candidates_p, top_p); + llama_sample_top_k(ctx, &candidates_p, top_k, 1); + llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); + llama_sample_typical(ctx, &candidates_p, typical_p, 1); + llama_sample_top_p(ctx, &candidates_p, top_p, 1); llama_sample_temperature(ctx, &candidates_p, temp); id = llama_sample_token(ctx, &candidates_p); } diff --git a/llama.cpp b/llama.cpp index 85af4dc49..c36c6ced6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1791,7 +1791,7 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k(nullptr, candidates, int(k)); + llama_sample_top_k(nullptr, candidates, int(k), 1); if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } diff --git a/llama.h b/llama.h index e993c464a..58c6e0699 100644 --- a/llama.h +++ b/llama.h @@ -202,16 +202,16 @@ extern "C" { LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1); + LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1); + LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 8ce59af3d..9174c1e37 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -32,7 +32,7 @@ void test_top_k(const std::vector & probs, llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_sample_softmax(nullptr, &candidates_p); DUMP(&candidates_p); - llama_sample_top_k(nullptr, &candidates_p, k); + llama_sample_top_k(nullptr, &candidates_p, k, 1); DUMP(&candidates_p); assert(candidates_p.size == expected_probs.size()); @@ -57,7 +57,7 @@ void test_top_p(const std::vector & probs, llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_sample_softmax(nullptr, &candidates_p); DUMP(&candidates_p); - llama_sample_top_p(nullptr, &candidates_p, p); + llama_sample_top_p(nullptr, &candidates_p, p, 1); DUMP(&candidates_p); assert(candidates_p.size == expected_probs.size()); @@ -80,7 +80,7 @@ void test_tfs(const std::vector & probs, llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sample_tail_free(nullptr, &candidates_p, z); + llama_sample_tail_free(nullptr, &candidates_p, z, 1); DUMP(&candidates_p); assert(candidates_p.size == expected_probs.size()); @@ -103,7 +103,7 @@ void test_typical(const std::vector & probs, llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sample_typical(nullptr, &candidates_p, p); + llama_sample_typical(nullptr, &candidates_p, p, 1); DUMP(&candidates_p); assert(candidates_p.size == expected_probs.size()); From 1b0fd454650ef4d68a980e3225488b79e6e9af25 Mon Sep 17 00:00:00 2001 From: swittk Date: Sun, 7 May 2023 10:03:23 +0700 Subject: [PATCH 17/27] ggml : Allow usage of CLBlast alongside Accelerate.framework (#1336) Minor edit in ggml.c which originally would prevent OpenCL from loading completely if GGML_USE_ACCELERATE was defined. Minor speedup in prompt eval time. --- ggml.c | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml.c b/ggml.c index 4d49242a4..1b89bdd89 100644 --- a/ggml.c +++ b/ggml.c @@ -137,6 +137,9 @@ inline static void* ggml_aligned_malloc(size_t size) { #if defined(GGML_USE_ACCELERATE) #include +#if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions +#include "ggml-opencl.h" +#endif #elif defined(GGML_USE_OPENBLAS) #include #elif defined(GGML_USE_CUBLAS) From e1295513a48ae8254d8af5ec0250b56d6eaffefd Mon Sep 17 00:00:00 2001 From: Henri Vasserman Date: Sun, 7 May 2023 14:20:09 +0300 Subject: [PATCH 18/27] CI: add Windows CLBlast and OpenBLAS builds (#1277) * Add OpenCL and CLBlast support * Add OpenBLAS support * Remove testing from matrix * change build name to 'clblast' --- .github/workflows/build.yml | 73 +++++++++++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 18bb33f94..a5938bf93 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -120,7 +120,7 @@ jobs: make macOS-latest-cmake: - runs-on: macOS-latest + runs-on: macos-latest steps: - name: Clone @@ -148,22 +148,64 @@ jobs: windows-latest-cmake: runs-on: windows-latest + env: + OPENBLAS_VERSION: 0.3.23 + OPENCL_VERSION: 2023.04.17 + CLBLAST_VERSION: 1.5.3 strategy: matrix: include: - - build: 'avx2' - defines: '' - - build: 'avx' - defines: '-DLLAMA_AVX2=OFF' - - build: 'avx512' - defines: '-DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON' + - build: 'avx2' + defines: '' + - build: 'avx' + defines: '-DLLAMA_AVX2=OFF' + - build: 'avx512' + defines: '-DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON' + - build: 'clblast' + defines: '-DLLAMA_CLBLAST=ON -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/clblast"' + - build: 'openblas' + defines: '-DLLAMA_OPENBLAS=ON -DBLAS_LIBRARIES="/LIBPATH:$env:RUNNER_TEMP/openblas/lib" -DOPENBLAS_INC="$env:RUNNER_TEMP/openblas/include"' steps: - name: Clone id: checkout uses: actions/checkout@v1 + - name: Download OpenCL SDK + id: get_opencl + if: ${{ matrix.build == 'clblast' }} + run: | + curl.exe -o $env:RUNNER_TEMP/opencl.zip -L "https://github.com/KhronosGroup/OpenCL-SDK/releases/download/v${env:OPENCL_VERSION}/OpenCL-SDK-v${env:OPENCL_VERSION}-Win-x64.zip" + mkdir $env:RUNNER_TEMP/opencl + tar.exe -xvf $env:RUNNER_TEMP/opencl.zip --strip-components=1 -C $env:RUNNER_TEMP/opencl + + - name: Download CLBlast + id: get_clblast + if: ${{ matrix.build == 'clblast' }} + run: | + curl.exe -o $env:RUNNER_TEMP/clblast.zip -L "https://github.com/CNugteren/CLBlast/releases/download/${env:CLBLAST_VERSION}/CLBlast-${env:CLBLAST_VERSION}-Windows-x64.zip" + curl.exe -o $env:RUNNER_TEMP/CLBlast.LICENSE.txt -L "https://github.com/CNugteren/CLBlast/raw/${env:CLBLAST_VERSION}/LICENSE" + mkdir $env:RUNNER_TEMP/clblast + tar.exe -xvf $env:RUNNER_TEMP/clblast.zip -C $env:RUNNER_TEMP/clblast + foreach ($f in (gci -Recurse -Path "$env:RUNNER_TEMP/clblast" -Filter '*.cmake')) { + $txt = Get-Content -Path $f -Raw + $txt.Replace('C:/dependencies/opencl/', "$($env:RUNNER_TEMP.Replace('\','/'))/opencl/") | Set-Content -Path $f -Encoding UTF8 + } + + - name: Download OpenBLAS + id: get_openblas + if: ${{ matrix.build == 'openblas' }} + run: | + curl.exe -o $env:RUNNER_TEMP/openblas.zip -L "https://github.com/xianyi/OpenBLAS/releases/download/v${env:OPENBLAS_VERSION}/OpenBLAS-${env:OPENBLAS_VERSION}-x64.zip" + curl.exe -o $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt -L "https://github.com/xianyi/OpenBLAS/raw/v${env:OPENBLAS_VERSION}/LICENSE" + mkdir $env:RUNNER_TEMP/openblas + tar.exe -xvf $env:RUNNER_TEMP/openblas.zip -C $env:RUNNER_TEMP/openblas + $vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath) + $msvc = $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim())) + $lib = $(join-path $msvc 'bin\Hostx64\x64\lib.exe') + & $lib /machine:x64 "/def:${env:RUNNER_TEMP}/openblas/lib/libopenblas.def" "/out:${env:RUNNER_TEMP}/openblas/lib/openblas.lib" /name:openblas.dll + - name: Build id: cmake_build run: | @@ -171,6 +213,21 @@ jobs: cd build cmake .. ${{ matrix.defines }} cmake --build . --config Release + cp ../LICENSE ./bin/Release/llama.cpp.txt + + - name: Add clblast.dll + id: add_clblast_dll + if: ${{ matrix.build == 'clblast' }} + run: | + cp $env:RUNNER_TEMP/clblast/lib/clblast.dll ./build/bin/Release + cp $env:RUNNER_TEMP/CLBlast.LICENSE.txt ./build/bin/Release/CLBlast-${env:CLBLAST_VERSION}.txt + + - name: Add libopenblas.dll + id: add_libopenblas_dll + if: ${{ matrix.build == 'openblas' }} + run: | + cp $env:RUNNER_TEMP/openblas/bin/libopenblas.dll ./build/bin/Release/openblas.dll + cp $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt ./build/bin/Release/OpenBLAS-${env:OPENBLAS_VERSION}.txt - name: Check AVX512F support id: check_avx512f @@ -187,7 +244,7 @@ jobs: - name: Test id: cmake_test - if: ${{ matrix.build != 'avx512' || env.HAS_AVX512F == '1' }} # Test AVX-512 only when possible + if: ${{ matrix.build != 'clblast' && (matrix.build != 'avx512' || env.HAS_AVX512F == '1') }} # Test AVX-512 only when possible run: | cd build ctest -C Release --verbose From 1f48b0abcfbd6cc99571e42348e0ec97e4be8b93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 8 May 2023 02:42:01 +0200 Subject: [PATCH 19/27] Documented CUDA reproducibility, added warning (#1346) --- README.md | 2 ++ examples/common.cpp | 3 +++ ggml-cuda.cu | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 19cc94aa2..6cbdcbf83 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,8 @@ Building the program with BLAS support may lead to some performance improvements cmake --build . --config Release ``` +Note: Because llama.cpp uses multiple CUDA streams for matrix multiplication results [are not guaranteed to be reproducible](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility). If you need reproducibility, set `GGML_CUDA_MAX_STREAMS` in the file `ggml-cuda.cu` to 1. + ### Prepare Data & Run ```bash diff --git a/examples/common.cpp b/examples/common.cpp index 97eded6ec..f1c3bae13 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -100,6 +100,9 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { arg = argv[i]; if (arg == "-s" || arg == "--seed") { +#if defined(GGML_USE_CUBLAS) + fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n"); +#endif if (++i >= argc) { invalid_param = true; break; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e8a1e77cb..127b352a0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -348,7 +348,7 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } -#define GGML_CUDA_MAX_STREAMS 8 +#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication. #define GGML_CUDA_MAX_EVENTS 64 static cublasHandle_t g_cublasH = nullptr; static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr }; From 95078cc554fe03d4512363c7e4dec963f0047c72 Mon Sep 17 00:00:00 2001 From: ubik2 Date: Mon, 8 May 2023 04:54:26 -0700 Subject: [PATCH 20/27] convert: add ability to convert safetensors files (#1276) * when loading a safetensors file, ignore the metadata header * check for safetensors files first, and only use PyTorch versions when safetensors aren't available --- convert.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/convert.py b/convert.py index 126beaabc..8f4f0399e 100644 --- a/convert.py +++ b/convert.py @@ -766,7 +766,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) description = f'safetensors begin={begin} end={end} type={data_type} path={path}' return LazyTensor(load, shape, data_type, description) - model = {name: convert(info) for (name, info) in header.items()} + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) @@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus: '''Load a model of any supported format.''' # Be extra-friendly and accept either a file or a directory: if path.is_dir(): - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] - files = [file for glob in globs for file in path.glob(glob)] + # Check if it's a set of safetensors files first + files = list(path.glob("model-00001-of-*.safetensors")) + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] + files = [file for glob in globs for file in path.glob(glob)] if not files: # Try GGML too, but with lower priority, since if both a non-GGML # model and a GGML model exist in the same directory, we assume the From f9a6364912fd0463fddfdbc9ef9f79fdc281570d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 May 2023 17:41:54 +0300 Subject: [PATCH 21/27] llama : require first token to be BOS (#1303) * llama : require first token to be BOS * scripts : add ppl-run-all.sh * perplexity : add BOS for each chunk * readme : update perplexity values after BOS fix * perplexity : add clarifying comments --- .gitignore | 1 + README.md | 32 +++++--------- examples/common.cpp | 4 +- examples/main/main.cpp | 4 +- examples/perplexity/perplexity.cpp | 70 ++++++++++++++++++++---------- llama.cpp | 12 ++++- scripts/ppl-run-all.sh | 43 ++++++++++++++++++ 7 files changed, 116 insertions(+), 50 deletions(-) create mode 100755 scripts/ppl-run-all.sh diff --git a/.gitignore b/.gitignore index 6f275fea4..a5fef3277 100644 --- a/.gitignore +++ b/.gitignore @@ -43,5 +43,6 @@ zig-out/ zig-cache/ ppl-*.txt +qnt-*.txt examples/jeopardy/results.txt diff --git a/README.md b/README.md index 6cbdcbf83..438748a91 100644 --- a/README.md +++ b/README.md @@ -298,17 +298,25 @@ Several quantization methods are supported. They differ in the resulting model d | Model | Measure | F16 | Q4_0 | Q4_1 | Q4_2 | Q5_0 | Q5_1 | Q8_0 | |------:|--------------|-------:|-------:|-------:|-------:|-------:|-------:|-------:| -| 7B | perplexity | 5.9565 | 6.2103 | 6.1286 | 6.1698 | 6.0139 | 5.9934 | 5.9571 | +| 7B | perplexity | 5.9066 | 6.1620 | 6.0910 | 6.1466 | 5.9862 | 5.9481 | 5.9069 | | 7B | file size | 13.0G | 4.0G | 4.8G | 4.0G | 4.4G | 4.8G | 7.1G | | 7B | ms/tok @ 4th | 128 | 56 | 61 | 84 | 91 | 95 | 75 | | 7B | ms/tok @ 8th | 128 | 47 | 55 | 48 | 53 | 59 | 75 | | 7B | bits/weight | 16.0 | 5.0 | 6.0 | 5.0 | 5.5 | 6.0 | 9.0 | -| 13B | perplexity | 5.2455 | 5.3748 | 5.3471 | 5.3433 | 5.2768 | 5.2582 | 5.2458 | +| 13B | perplexity | 5.2543 | 5.3863 | 5.3607 | 5.3513 | 5.2856 | 5.2706 | 5.2548 | | 13B | file size | 25.0G | 7.6G | 9.1G | 7.6G | 8.4G | 9.1G | 14G | | 13B | ms/tok @ 4th | 239 | 104 | 113 | 160 | 176 | 185 | 141 | | 13B | ms/tok @ 8th | 240 | 85 | 99 | 97 | 108 | 117 | 147 | | 13B | bits/weight | 16.0 | 5.0 | 6.0 | 5.0 | 5.5 | 6.0 | 9.0 | +### Perplexity (measuring model quality) + +You can use the `perplexity` example to measure perplexity over a given prompt (lower perplexity is better). +For more information, see [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity). + +The perplexity measurements in table above are done against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with context length of 512. +The time per token is measured on a MacBook M1 Pro 32GB RAM using 4 and 8 threads. + ### Interactive mode If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter. @@ -407,26 +415,6 @@ If your issue is with model generation quality, then please at least scan the fo - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) -### Perplexity (measuring model quality) - -You can use the `perplexity` example to measure perplexity over the given prompt. For more background, see [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity). However, in general, lower perplexity is better for LLMs. - -#### Latest measurements - -The latest perplexity scores for the various model sizes and quantizations are being tracked in [discussion #406](https://github.com/ggerganov/llama.cpp/discussions/406). `llama.cpp` is measuring very well compared to the baseline implementations. Quantization has a small negative impact on quality, but, as you can see, running -13B at q4_0 beats the 7B f16 model by a significant amount. - -All measurements are done against the wikitext2 test dataset (https://paperswithcode.com/dataset/wikitext-2), with default options (512 length context). -Note that changing the context length will have a significant impact on perplexity (longer context = better perplexity). -``` -Perplexity - model options -5.5985 - 13B, q4_0 -5.9565 - 7B, f16 -6.3001 - 7B, q4_1 -6.5949 - 7B, q4_0 -6.5995 - 7B, q4_0, --memory_f16 -``` - #### How to run 1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research diff --git a/examples/common.cpp b/examples/common.cpp index f1c3bae13..6af440272 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -438,8 +438,8 @@ std::string gpt_random_prompt(std::mt19937 & rng) { // TODO: not great allocating this every time std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars - std::vector res(text.size() + (int)add_bos); - int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); + std::vector res(text.size() + (int) add_bos); + const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); assert(n >= 0); res.resize(n); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 5ac151e14..045093c72 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -313,7 +313,8 @@ int main(int argc, char ** argv) { if (n_past + (int) embd.size() > n_ctx) { const int n_left = n_past - params.n_keep; - n_past = params.n_keep; + // always keep the first token - BOS + n_past = std::max(1, params.n_keep); // insert n_left/2 tokens at the start of embd from last_n_tokens embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); @@ -331,7 +332,6 @@ int main(int argc, char ** argv) { } // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) - // REVIEW if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 299a19999..9212dee5c 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -25,46 +25,68 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` + // BOS tokens will be added for each chunk before eval auto tokens = ::llama_tokenize(ctx, params.prompt, true); - int count = 0; - int seq_count = tokens.size() / params.n_ctx; - int n_vocab = llama_n_vocab(ctx); + int count = 0; + + const int n_chunk = tokens.size() / params.n_ctx; + const int n_vocab = llama_n_vocab(ctx); + const int n_batch = params.n_batch; double nll = 0.0; - fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch); + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); - for (int i = 0; i < seq_count; ++i) { - int start = i * params.n_ctx; - int end = start + params.n_ctx; + for (int i = 0; i < n_chunk; ++i) { + const int start = i * params.n_ctx; + const int end = start + params.n_ctx; + + const int num_batches = (params.n_ctx + n_batch - 1) / n_batch; std::vector logits; - int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch; - auto start_t = std::chrono::high_resolution_clock::now(); + + const auto t_start = std::chrono::high_resolution_clock::now(); + for (int j = 0; j < num_batches; ++j) { - int batch_start = start + j * params.n_batch; - int batch_size = std::min(end - batch_start, params.n_batch); - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) { + const int batch_start = start + j * n_batch; + const int batch_size = std::min(end - batch_start, n_batch); + + // save original token and restore it after eval + const auto token_org = tokens[batch_start]; + + // add BOS token for the first batch of each chunk + if (j == 0) { + tokens[batch_start] = llama_token_bos(); + } + + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } - auto batch_logits = llama_get_logits(ctx); + + // restore the original token in case it was set to BOS + tokens[batch_start] = token_org; + + const auto batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } - auto end_t = std::chrono::high_resolution_clock::now(); + + const auto t_end = std::chrono::high_resolution_clock::now(); + if (i == 0) { - const float seconds = std::chrono::duration(end_t - start_t).count(); - printf("%.2f seconds per pass - ETA ", seconds); - int total_seconds = (int)(seconds * seq_count); + const float t_total = std::chrono::duration(t_end - t_start).count(); + fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk); if (total_seconds >= 60*60) { - printf("%d hours ", total_seconds / (60*60)); + fprintf(stderr, "%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } - printf("%d minutes\n", total_seconds / 60); + fprintf(stderr, "%d minutes\n", total_seconds / 60); } + // We get the logits for all the tokens in the context window (params.n_ctx) // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, - // calculate the perplexity over the last half the window (so the model always has + // calculate the perplexity over the last half of the window (so the model always has // some context to predict the token). // // We rely on the fact that attention in the forward pass only looks at previous @@ -76,10 +98,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // process the entire prompt. for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) { // Calculate probability of next token, given the previous ones. - std::vector tok_logits( - logits.begin() + j * n_vocab, + const std::vector tok_logits( + logits.begin() + (j + 0) * n_vocab, logits.begin() + (j + 1) * n_vocab); - float prob = softmax(tok_logits)[tokens[start + j + 1]]; + + const float prob = softmax(tok_logits)[tokens[start + j + 1]]; + nll += -std::log(prob); ++count; } diff --git a/llama.cpp b/llama.cpp index c36c6ced6..d54fa502c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1052,6 +1052,13 @@ static bool llama_eval_internal( const int n_tokens, const int n_past, const int n_threads) { + + // enforce that the first token is BOS + if (n_past == 0 && tokens[0] != llama_token_bos()) { + fprintf(stderr, "%s: first token must be BOS\n", __func__); + return false; + } + const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1482,7 +1489,7 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co } if (bos) { - output.push_back(1); + output.push_back(llama_token_bos()); } tokenizer.tokenize(text, output); @@ -2727,11 +2734,14 @@ int llama_eval( fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } + // get a more accurate load time, upon first eval + // TODO: fix this if (!ctx->has_evaluated_once) { ctx->t_load_us = ggml_time_us() - ctx->t_start_us; ctx->has_evaluated_once = true; } + return 0; } diff --git a/scripts/ppl-run-all.sh b/scripts/ppl-run-all.sh new file mode 100755 index 000000000..28f31ca71 --- /dev/null +++ b/scripts/ppl-run-all.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# +# quantize +# + +# 7B +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_0.bin q4_0 2>&1 | tee ../qnt-7b-q4_0.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_1.bin q4_1 2>&1 | tee ../qnt-7b-q4_1.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q4_2.bin q4_2 2>&1 | tee ../qnt-7b-q4_2.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q5_0.bin q5_0 2>&1 | tee ../qnt-7b-q5_0.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q5_1.bin q5_1 2>&1 | tee ../qnt-7b-q5_1.txt +time ./bin/quantize ../models/7B/ggml-model-f16.bin ../models/7B/ggml-model-q8_0.bin q8_0 2>&1 | tee ../qnt-7b-q8_0.txt + +# 13B +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_0.bin q4_0 2>&1 | tee ../qnt-13b-q4_0.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_1.bin q4_1 2>&1 | tee ../qnt-13b-q4_1.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q4_2.bin q4_2 2>&1 | tee ../qnt-13b-q4_2.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q5_0.bin q5_0 2>&1 | tee ../qnt-13b-q5_0.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q5_1.bin q5_1 2>&1 | tee ../qnt-13b-q5_1.txt +time ./bin/quantize ../models/13B/ggml-model-f16.bin ../models/13B/ggml-model-q8_0.bin q8_0 2>&1 | tee ../qnt-13b-q8_0.txt + +# +# perplexity +# + +# 7B +time ./bin/perplexity -m ../models/7B/ggml-model-f16.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-f16.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_0.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_1.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q4_2.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q4_2.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q5_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q5_0.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q5_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q5_1.txt +time ./bin/perplexity -m ../models/7B/ggml-model-q8_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-7b-q8_0.txt + +# 13B +time ./bin/perplexity -m ../models/13B/ggml-model-f16.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-f16.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_0.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_1.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q4_2.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q4_2.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q5_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q5_0.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q5_1.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q5_1.txt +time ./bin/perplexity -m ../models/13B/ggml-model-q8_0.bin -f ./wiki.test.raw --no-mmap -t 12 2>&1 | tee ../ppl-13b-q8_0.txt From 003ba2fb4309e2339487564bd249e4fcc8d7ea01 Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Mon, 8 May 2023 16:48:21 +0200 Subject: [PATCH 22/27] llama : fix hparams shadow (#1367) fixes #1363 --- llama.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index d54fa502c..4bba93a11 100644 --- a/llama.cpp +++ b/llama.cpp @@ -970,8 +970,6 @@ static void llama_model_load_internal( // prepare memory for the weights { - const auto & hparams = model.hparams; - const uint32_t n_embd = hparams.n_embd; const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; From fe60904eef4b504685fa0406cb19864ae619fb4f Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 8 May 2023 21:03:30 +0430 Subject: [PATCH 23/27] readme : add TOC and Pygmalion instructions (#1359) --- README.md | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/README.md b/README.md index 438748a91..f029f06a8 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,39 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ - [Roadmap May 2023](https://github.com/ggerganov/llama.cpp/discussions/1220) - [New quantization methods](https://github.com/ggerganov/llama.cpp#quantization) +
+ Table of Contents +
    +
  1. + Description +
  2. +
  3. + Usage + +
  4. +
  5. Contributing
  6. +
  7. Coding guidelines
  8. +
  9. Docs
  10. +
+
+ ## Description The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quantization on a MacBook @@ -46,6 +79,7 @@ as the main playground for developing new features for the [ggml](https://github - [X] [Vicuna](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5533894) - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) - [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy) +- [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b) **Bindings:** @@ -383,6 +417,19 @@ python3 convert.py models/gpt4all-7B/gpt4all-lora-quantized.bin - The newer GPT4All-J model is not yet supported! +### Using Pygmalion 7B & Metharme 7B + +- Obtain the [LLaMA weights](#obtaining-the-facebook-llama-original-model-and-stanford-alpaca-model-data) +- Obtain the [Pygmalion 7B](https://huggingface.co/PygmalionAI/pygmalion-7b/) or [Metharme 7B](https://huggingface.co/PygmalionAI/metharme-7b) XOR encoded weights +- Convert the LLaMA model with [the latest HF convert script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) +- Merge the XOR files with the converted LLaMA weights by running the [xor_codec](https://huggingface.co/PygmalionAI/pygmalion-7b/blob/main/xor_codec.py) script +- Convert to `ggml` format using the `convert.py` script in this repo: +```bash +python3 convert.py pygmalion-7b/ --outtype q4_1 +``` +> The Pygmalion 7B & Metharme 7B weights are saved in [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) precision. If you wish to convert to `ggml` without quantizating, please specify the `--outtype` as `f32` instead of `f16`. + + ### Obtaining the Facebook LLaMA original model and Stanford Alpaca model data - **Under no circumstances should IPFS, magnet links, or any other links to model downloads be shared anywhere in this repository, including in issues, discussions, or pull requests. They will be immediately deleted.** From 56551bc11f46b2716fdf61bb48ac28414889dc0a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 May 2023 22:52:18 +0300 Subject: [PATCH 24/27] readme : add notice about upcoming breaking change --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index f029f06a8..045f99534 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,14 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ +## ⚠️ TEMPORARY NOTICE ABOUT UPCOMING BREAKING CHANGE ⚠️ + +**The quantization formats will soon be updated: https://github.com/ggerganov/llama.cpp/pull/1305** + +**All `ggml` model files using the old format will not work with the latest `llama.cpp` code after that change is merged** + +--- + **Hot topics:** - [Roadmap May 2023](https://github.com/ggerganov/llama.cpp/discussions/1220) From 41654efea879bbdf4fd794e13335929d4cf0eb90 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Mon, 8 May 2023 19:45:48 -0700 Subject: [PATCH 25/27] Interface improvements and `--multiline-input` (previously `--author-mode`) (#1040) * Interface improvements * Multiline input * Track character width * Works with all characters and control codes + Windows console fixes --- examples/common.cpp | 384 +++++++++++++++++++++++++++++++++++------ examples/common.h | 25 ++- examples/main/main.cpp | 60 +++---- 3 files changed, 374 insertions(+), 95 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 6af440272..23d69e7d5 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -14,20 +14,16 @@ #include #endif -#if defined (_WIN32) +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include #include #include -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int CodePage, unsigned long dwFlags, - const wchar_t * lpWideCharStr, int cchWideChar, - char * lpMultiByteStr, int cbMultiByte, - const char * lpDefaultChar, bool * lpUsedDefaultChar); -#define CP_UTF8 65001 +#else +#include +#include +#include #endif int32_t get_num_physical_cores() { @@ -269,6 +265,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.interactive_first = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "--multiline-input") { + params.multiline_input = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -359,6 +357,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -479,54 +478,339 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { return lctx; } -/* Keep track of current color of output, and emit ANSI code if it changes. */ -void set_console_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - switch(color) { - case CONSOLE_COLOR_DEFAULT: - printf(ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - printf(ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - break; - } - con_st.color = color; - } -} - -#if defined (_WIN32) -void win32_console_init(bool enable_color) { - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; +void console_init(console_state & con_st) { +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { + con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { + con_st.hConsole = NULL; } } - if (hConOut) { + if (con_st.hConsole) { // Enable ANSI colors on Windows 10+ - if (enable_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); } // Set console output codepage to UTF8 SetConsoleOutputCP(CP_UTF8); } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); + + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + SetConsoleMode(hConIn, dwMode); + } +#else + // POSIX-specific console initialization + struct termios new_termios; + tcgetattr(STDIN_FILENO, &con_st.prev_state); + new_termios = con_st.prev_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + con_st.tty = fopen("/dev/tty", "w+"); + if (con_st.tty != nullptr) { + con_st.out = con_st.tty; + } +#endif + setlocale(LC_ALL, ""); +} + +void console_cleanup(console_state & con_st) { + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + +#if !defined(_WIN32) + if (con_st.tty != nullptr) { + con_st.out = stdout; + fclose(con_st.tty); + con_st.tty = nullptr; + } + // Restore the terminal settings on POSIX systems + tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); +#endif +} + +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void console_set_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + fflush(stdout); + switch(color) { + case CONSOLE_COLOR_DEFAULT: + fprintf(con_st.out, ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + fprintf(con_st.out, ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + fflush(con_st.out); } } -// Convert a wide Unicode string to an UTF8 string -void win32_utf8_encode(const std::wstring & wstr, std::string & str) { - int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); - std::string strTo(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); - str = strTo; -} +char32_t getchar32() { + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } #endif + + return static_cast(wc); +} + +void pop_cursor(console_state & con_st) { +#if defined(_WIN32) + if (con_st.hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); + return; + } +#endif + putc('\b', con_st.out); +} + +int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif +} + +int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // we can trust expectedWidth if we've got one + if (expectedWidth >= 0 || con_st.tty == nullptr) { + fwrite(utf8_codepoint, length, 1, con_st.out); + return expectedWidth; + } + + fputs("\033[6n", con_st.tty); // Query cursor position + int x1, x2, y1, y2; + int results = 0; + results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, con_st.tty); + + fputs("\033[6n", con_st.tty); // Query cursor position + results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif +} + +void replace_last(console_state & con_st, char ch) { +#if defined(_WIN32) + pop_cursor(con_st); + put_codepoint(con_st, &ch, 1, 1); +#else + fprintf(con_st.out, "\b%c", ch); +#endif +} + +void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } +} + +// Helper function to remove the last UTF-8 character from a string +void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character + } + line.erase(pos); +} + +bool console_readline(console_state & con_st, std::string & line) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + if (con_st.out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(con_st.out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + replace_last(con_st, line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(con_st, ' '); + pop_cursor(con_st); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + console_set_color(con_st, CONSOLE_COLOR_PROMPT); + replace_last(con_st, line.back()); + is_special_char = true; + } + } + + bool has_more = con_st.multiline_input; + if (is_special_char) { + replace_last(con_st, ' '); + pop_cursor(con_st); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', con_st.out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(con_st); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', con_st.out); + } + } + + fflush(con_st.out); + return has_more; +} diff --git a/examples/common.h b/examples/common.h index 842e1516f..43f1cc9ef 100644 --- a/examples/common.h +++ b/examples/common.h @@ -10,6 +10,11 @@ #include #include +#if !defined (_WIN32) +#include +#include +#endif + // // CLI argument parsing // @@ -56,6 +61,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately + bool multiline_input = false; // reverse the usage of `\` bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token @@ -104,13 +110,20 @@ enum console_color_t { }; struct console_state { + bool multiline_input = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; + + FILE* out = stdout; +#if defined (_WIN32) + void* hConsole; +#else + FILE* tty = nullptr; + termios prev_state; +#endif }; -void set_console_color(console_state & con_st, console_color_t color); - -#if defined (_WIN32) -void win32_console_init(bool enable_color); -void win32_utf8_encode(const std::wstring & wstr, std::string & str); -#endif +void console_init(console_state & con_st); +void console_cleanup(console_state & con_st); +void console_set_color(console_state & con_st, console_color_t color); +bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 045093c72..6e1172a48 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -35,12 +35,12 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; } else { + console_cleanup(con_st); + printf("\n"); llama_print_timings(*g_ctx); _exit(130); } @@ -59,10 +59,9 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) con_st.use_color = params.use_color; - -#if defined (_WIN32) - win32_console_init(params.use_color); -#endif + con_st.multiline_input = params.multiline_input; + console_init(con_st); + atexit([]() { console_cleanup(con_st); }); if (params.perplexity) { printf("\n************\n"); @@ -275,12 +274,21 @@ int main(int argc, char ** argv) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { + const char *control_message; + if (con_st.multiline_input) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } fprintf(stderr, "== Running in interactive mode. ==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); + "%s\n", control_message); + is_interacting = params.interactive_first; } @@ -299,7 +307,7 @@ int main(int argc, char ** argv) { int n_session_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_color(con_st, CONSOLE_COLOR_PROMPT); + console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; @@ -498,7 +506,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -518,17 +526,12 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; is_antiprompt = true; - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - fflush(stdout); break; } } } if (n_past > 0 && is_interacting) { - // potentially set color to indicate we are taking user input - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (params.instruct) { printf("\n> "); } @@ -542,31 +545,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { -#if defined(_WIN32) - std::wstring wline; - if (!std::getline(std::wcin, wline)) { - // input stream is bad or EOF received - return 0; - } - win32_utf8_encode(wline, line); -#else - if (!std::getline(std::cin, line)) { - // input stream is bad or EOF received - return 0; - } -#endif - if (!line.empty()) { - if (line.back() == '\\') { - line.pop_back(); // Remove the continue character - } else { - another_line = false; - } - buffer += line + '\n'; // Append the line to the result - } + another_line = console_readline(con_st, line); + buffer += line; } while (another_line); // done taking input, reset color - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back @@ -622,7 +606,5 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - return 0; } From 9f8dbc4787f32cd9c13b5c647d497d13c1a06db2 Mon Sep 17 00:00:00 2001 From: Sami Farin <3876865+Safari77@users.noreply.github.com> Date: Tue, 9 May 2023 15:29:20 +0300 Subject: [PATCH 26/27] =?UTF-8?q?use=20pause=20asm=20insn=20in=20busyloop?= =?UTF-8?q?=20to=20run=20the=20CPU=20(13600K)=2010=20=C2=B0C=20cooler=20(#?= =?UTF-8?q?1314)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * use pause asm insn in busyloop to run the CPU (13600K) 10 °C cooler Tested with a 13B model. * use _mm_pause() in busyloop * use _mm_pause() in busyloop on x86_64 to reduce power consumption --- ggml.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml.c b/ggml.c index 1b89bdd89..4e309df8a 100644 --- a/ggml.c +++ b/ggml.c @@ -11663,7 +11663,11 @@ typedef int ggml_lock_t; #define ggml_lock_init(x) UNUSED(x) #define ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define ggml_lock_lock(x) _mm_pause() +#else #define ggml_lock_lock(x) UNUSED(x) +#endif #define ggml_lock_unlock(x) UNUSED(x) #define GGML_LOCK_INITIALIZER 0 From e6a46b0ed1884c77267dc70693183e3b7164e0e0 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Tue, 9 May 2023 10:53:28 -0700 Subject: [PATCH 27/27] Locale fix for Windows (#1379) --- examples/common.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index 23d69e7d5..7aa77587b 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -520,8 +520,9 @@ void console_init(console_state & con_st) { if (con_st.tty != nullptr) { con_st.out = con_st.tty; } -#endif + setlocale(LC_ALL, ""); +#endif } void console_cleanup(console_state & con_st) {