diff --git a/README.md b/README.md
index 1fe5b5426..dae1bf1b8 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++

 **Hot topics:**

-- RMSNorm implementation / fixes: https://github.com/ggerganov/llama.cpp/issues/173
+- [Added Alpaca support](https://github.com/ggerganov/llama.cpp#instruction-mode-with-alpaca)
 - Cache input prompts for faster initialization: https://github.com/ggerganov/llama.cpp/issues/64
 - Create a `llama.cpp` logo: https://github.com/ggerganov/llama.cpp/issues/105

@@ -147,7 +147,7 @@ python3 -m pip install torch numpy sentencepiece
 python3 convert-pth-to-ggml.py models/7B/ 1

 # quantize the model to 4-bits
-./quantize.sh 7B
+python3 quantize.py 7B

 # run the inference
 ./main -m ./models/7B/ggml-model-q4_0.bin -n 128
@@ -176,21 +176,51 @@ In this mode, you can always interrupt generation by pressing Ctrl+C and enter o
 Here is an example few-shot interaction, invoked with the command

 ```
-./main -m ./models/13B/ggml-model-q4_0.bin -n 256 --repeat_penalty 1.0 --color -i -r "User:" \
-       -p \
-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
-
-User: Hello, Bob.
-Bob: Hello. How may I help you today?
-User: Please tell me the largest city in Europe.
-Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
-User:"
+./main -m ./models/13B/ggml-model-q4_0.bin -n 256 --repeat_penalty 1.0 --color -i -r "User:" -f prompts/chat-with-bob.txt
 ```

 Note the use of `--color` to distinguish between user input and generated text.

 ![image](https://user-images.githubusercontent.com/1991296/224575029-2af3c7dc-5a65-4f64-a6bb-517a532aea38.png)

+### Instruction mode with Alpaca
+
+First, download the `ggml` Alpaca model into the `./models` folder:
+
+```
+# use one of these
+# NOTE: these are copied from the alpaca.cpp repo - not sure how long these will work
+# TODO: add a script to simplify the download
+curl -o ggml-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
+curl -o ggml-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
+curl -o ggml-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
+```
+
+Now run the `main` tool like this:
+
+```
+./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins
+```
+
+Sample run:
+
+```
+== Running in interactive mode. ==
+ - Press Ctrl+C to interject at any time.
+ - Press Return to return control to LLaMa.
+ - If you want to submit another line, end your input in '\'.
+
+ Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+> How many letters are there in the English alphabet?
+There 26 letters in the English Alphabet
+> What is the most common way of transportation in Amsterdam?
+The majority (54%) are using public transit. This includes buses, trams and metros with over 100 lines throughout the city which make it very accessible for tourists to navigate around town as well as locals who commute by tram or metro on a daily basis
+> List 5 words that start with "ca".
+cadaver, cauliflower, cabbage (vegetable), catalpa (tree) and Cailleach.
+>
+```
+
 ### Android

 You can easily run `llama.cpp` on Android device with [termux](https://play.google.com/store/apps/details?id=com.termux).
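The `-ins` instruction mode documented above wraps every user input in the Alpaca template: the base prompt from `prompts/alpaca.txt` is followed by an `### Instruction:` block containing the user's text and a `### Response:` marker, matching the `inp_pfx`/`inp_sfx` strings added to `main.cpp` later in this patch. Below is a minimal Python sketch of the prompt the model ends up seeing; the helper function is illustrative and not part of the patch.

```python
# Sketch only: shows how instruct mode assembles the prompt text.
# The template line and the markers come from prompts/alpaca.txt and the
# inp_pfx/inp_sfx strings in main.cpp; build_alpaca_prompt() is hypothetical.
BASE_PROMPT = ("Below is an instruction that describes a task. "
               "Write a response that appropriately completes the request.")
INSTRUCTION_PREFIX = "\n\n### Instruction:\n\n"
RESPONSE_SUFFIX = "\n\n### Response:\n\n"

def build_alpaca_prompt(user_input: str, context: str = BASE_PROMPT) -> str:
    """Wrap one round of user input in the Alpaca instruction template."""
    return context + INSTRUCTION_PREFIX + user_input + RESPONSE_SUFFIX

if __name__ == "__main__":
    print(build_alpaca_prompt("How many letters are there in the English alphabet?"))
```

In the C++ code the same thing happens at the token level: the prefix tokens are inserted before the user's tokenized line and the suffix tokens after it, and `### Instruction:` is also registered as a reverse prompt so generation hands control back to the user.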
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index d0eb213c8..c1941a811 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -16,7 +16,7 @@ # At the start of the ggml file we write the model parameters # and vocabulary. # -import os +import argparse import sys import json import struct @@ -24,136 +24,81 @@ import numpy as np import torch from sentencepiece import SentencePieceProcessor -if len(sys.argv) < 3: - print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n") - print(" ftype == 0 -> float32") - print(" ftype == 1 -> float16") - sys.exit(1) +def parse_args(): -# output in the same directory as the model -dir_model = sys.argv[1] - -fname_hparams = sys.argv[1] + "/params.json" -fname_tokenizer = sys.argv[1] + "/../tokenizer.model" + parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') + parser.add_argument('dir_model', help='directory containing the model checkpoint') + parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)') + return parser.parse_args() def get_n_parts(dim): - if dim == 4096: - return 1 - elif dim == 5120: - return 2 - elif dim == 6656: - return 4 - elif dim == 8192: - return 8 - else: - print("Invalid dim: " + str(dim)) + + mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8} + n_parts = mappings.get(dim) + if n_parts is None: + print(f"Invalid dim: {dim}") sys.exit(1) -# possible data types -# ftype == 0 -> float32 -# ftype == 1 -> float16 -# -# map from ftype to string -ftype_str = ["f32", "f16"] + print(f"n_parts = {n_parts}\n") + return n_parts -ftype = 1 -if len(sys.argv) > 2: - ftype = int(sys.argv[2]) - if ftype < 0 or ftype > 1: - print("Invalid ftype: " + str(ftype)) - sys.exit(1) - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" +def load_hparams_and_tokenizer(dir_model): -if os.path.exists(fname_out): - print(f"Skip conversion, it already exists: {fname_out}") - sys.exit(0) + fname_hparams = f"{dir_model}/params.json" + fname_tokenizer = f"{dir_model}/../tokenizer.model" -with open(fname_hparams, "r") as f: - hparams = json.load(f) + with open(fname_hparams, "r") as f: + hparams = json.load(f) + print(hparams) -tokenizer = SentencePieceProcessor(fname_tokenizer) + tokenizer = SentencePieceProcessor(fname_tokenizer) + hparams.update({"vocab_size": tokenizer.vocab_size()}) -hparams.update({"vocab_size": tokenizer.vocab_size()}) + return hparams, tokenizer -n_parts = get_n_parts(hparams["dim"]) +def write_header(fout, hparams, ftype): -print(hparams) -print('n_parts = ', n_parts) + keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"] + values = [ + 0x67676d6c, # magic: ggml in hex + *[hparams[key] for key in keys], + hparams["dim"] // hparams["n_heads"], # rot (obsolete) + ftype + ] + fout.write(struct.pack("i" * len(values), *values)) -for p in range(n_parts): - print('Processing part ', p) +def write_tokens(fout, tokenizer): - #fname_model = sys.argv[1] + "/consolidated.00.pth" - fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth" - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" - if (p > 0): - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." 
+ str(p) - - model = torch.load(fname_model, map_location="cpu") - - fout = open(fname_out, "wb") - - fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex - fout.write(struct.pack("i", hparams["vocab_size"])) - fout.write(struct.pack("i", hparams["dim"])) - fout.write(struct.pack("i", hparams["multiple_of"])) - fout.write(struct.pack("i", hparams["n_heads"])) - fout.write(struct.pack("i", hparams["n_layers"])) - fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete) - fout.write(struct.pack("i", ftype)) - - # Is this correct?? for i in range(tokenizer.vocab_size()): if tokenizer.is_unknown(i): - # "" token (translated as ??) text = " \u2047 ".encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) elif tokenizer.is_control(i): - # ""/"" tokens - fout.write(struct.pack("i", 0)) + text = b"" elif tokenizer.is_byte(i): - # "" tokens (which may be invalid UTF-8) piece = tokenizer.id_to_piece(i) if len(piece) != 6: - print("Invalid token: " + piece) + print(f"Invalid token: {piece}") sys.exit(1) byte_value = int(piece[3:-1], 16) - fout.write(struct.pack("i", 1)) - fout.write(struct.pack("B", byte_value)) + text = struct.pack("B", byte_value) else: - # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces. text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) + fout.write(struct.pack("i", len(text))) + fout.write(text) - for k, v in model.items(): - name = k - shape = v.shape +def process_and_write_variables(fout, model, ftype): - # skip layers.X.attention.inner_attention.rope.freqs - if name[-5:] == "freqs": + for name, datao in model.items(): + + if name.endswith("freqs"): continue - print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype) + shape = datao.shape - #data = tf.train.load_variable(dir_model, name).squeeze() - data = v.numpy().squeeze() - n_dims = len(data.shape); + print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}") - # for efficiency - transpose some matrices - # "model/h.*/attn/c_attn/w" - # "model/h.*/attn/c_proj/w" - # "model/h.*/mlp/c_fc/w" - # "model/h.*/mlp/c_proj/w" - #if name[-14:] == "/attn/c_attn/w" or \ - # name[-14:] == "/attn/c_proj/w" or \ - # name[-11:] == "/mlp/c_fc/w" or \ - # name[-13:] == "/mlp/c_proj/w": - # print(" Transposing") - # data = data.transpose() - - dshape = data.shape + data = datao.numpy().squeeze() + n_dims = len(shape) # default type is fp16 ftype_cur = 1 @@ -164,18 +109,40 @@ for p in range(n_parts): # header sname = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) - for i in range(n_dims): - fout.write(struct.pack("i", dshape[n_dims - 1 - i])) - fout.write(sname); + fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur)) + for dim in reversed(data.shape): + fout.write(struct.pack("i", dim)) + fout.write(sname) - # data + # data output to file data.tofile(fout) - # I hope this deallocates the memory .. - model = None +def main(): - fout.close() + args = parse_args() + dir_model = args.dir_model + ftype = args.ftype + ftype_str = ["f32", "f16"] - print("Done. 
Output file: " + fname_out + ", (part ", p, ")") - print("") + hparams, tokenizer = load_hparams_and_tokenizer(dir_model) + n_parts = get_n_parts(hparams["dim"]) + + for p in range(n_parts): + + print(f"Processing part {p}\n") + + fname_model = f"{dir_model}/consolidated.0{p}.pth" + fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}" + + model = torch.load(fname_model, map_location="cpu") + + with open(fname_out, "wb") as fout: + write_header(fout, hparams, ftype) + write_tokens(fout, tokenizer) + process_and_write_variables(fout, model, ftype) + + del model + print(f"Done. Output file: {fname_out}, (part {p})\n") + +if __name__ == "__main__": + main() diff --git a/ggml.c b/ggml.c index 4fb83adbd..4813f74c8 100644 --- a/ggml.c +++ b/ggml.c @@ -5556,7 +5556,7 @@ static void ggml_compute_forward_rms_norm_f32( const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - const ggml_float eps = 1e-5f; // TODO: make this a parameter + const ggml_float eps = 1e-6f; // TODO: make this a parameter // TODO: optimize for (int i03 = 0; i03 < ne03; i03++) { @@ -5572,7 +5572,7 @@ static void ggml_compute_forward_rms_norm_f32( mean /= ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - + memcpy(y, x, ne00 * sizeof(float)); // for (int i00 = 0; i00 < ne00; i00++) { // y[i00] = x[i00]; diff --git a/main.cpp b/main.cpp index 58b2cb68d..38d11924d 100644 --- a/main.cpp +++ b/main.cpp @@ -27,6 +27,8 @@ #define ANSI_COLOR_RESET "\x1b[0m" #define ANSI_BOLD "\x1b[1m" +static const int EOS_TOKEN_ID = 2; + // determine number of model parts based on the dimension static const std::map LLAMA_N_PARTS = { { 4096, 1 }, @@ -86,7 +88,7 @@ struct llama_model { }; // load the model's weights from a file -bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) { +bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) { fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); std::vector f_buf(1024*1024); @@ -176,8 +178,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab } } - const ggml_type wtype2 = GGML_TYPE_F32; - auto & ctx = model.ctx; size_t ctx_size = 0; @@ -209,8 +209,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2 ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3 - ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k - ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v ctx_size += (5 + 10*n_layer)*256; // object overhead @@ -237,7 +237,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; model.layers.resize(n_layer); @@ -296,8 +295,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab const int n_mem = n_layer*n_ctx; const int n_elements = n_embd*n_mem; - model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); - model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); + model.memory_k = 
ggml_new_tensor_1d(ctx, memory_type, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements); const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); @@ -539,9 +538,7 @@ bool llama_eval( const int n_vocab = hparams.n_vocab; const int n_rot = hparams.n_embd/hparams.n_head; - const int d_key = n_embd/n_head; - - // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case + // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case // static size_t buf_size = hparams.n_ctx*1024*1024; static size_t buf_size = 512u*1024*1024; static void * buf = malloc(buf_size); @@ -752,6 +749,7 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { printf(ANSI_COLOR_RESET); + printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; @@ -792,7 +790,7 @@ int main(int argc, char ** argv) { if (gpt_params_parse(argc, argv, params) == false) { return 1; } - + if (params.n_ctx > 2048) { fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);" "expect poor results\n", __func__, params.n_ctx); @@ -805,7 +803,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); std::mt19937 rng(params.seed); - if (params.prompt.empty()) { + if (params.random_prompt) { params.prompt = gpt_random_prompt(rng); } @@ -819,8 +817,9 @@ int main(int argc, char ** argv) { // load the model { + const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; const int64_t t_start_us = ggml_time_us(); - if (!llama_model_load(params.model, model, vocab, params.n_ctx)) { + if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return 1; } @@ -849,6 +848,16 @@ int main(int argc, char ** argv) { params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); + // prefix & suffix for instruct mode + const std::vector inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true); + const std::vector inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false); + + // in instruct mode, we inject a prefix and a suffix to each input by the user + if (params.instruct) { + params.interactive = true; + params.antiprompt.push_back("### Instruction:\n\n"); + } + // tokenize the reverse prompt std::vector> antipromptv_inp; @@ -856,6 +865,11 @@ int main(int argc, char ** argv) { antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false)); } + // enable interactive mode if reverse prompt is specified + if (!antipromptv_inp.size()) { + params.interactive = true; + } + fprintf(stderr, "\n"); fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); @@ -901,31 +915,27 @@ int main(int argc, char ** argv) { std::vector last_n_tokens(last_n_size); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); - if (params.interactive) { fprintf(stderr, "== Running in interactive mode. 
==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n"); + " - If you want to submit another line, end your input in '\\'.\n\n"); + is_interacting = true; } - int remaining_tokens = params.n_predict; int input_consumed = 0; bool input_noecho = false; - // prompt user immediately after the starting prompt has been loaded - if (params.interactive_start) { - is_interacting = true; - } + int remaining_tokens = params.n_predict; // set the color for the prompt which will be output initially if (params.use_color) { printf(ANSI_COLOR_YELLOW); } - while (remaining_tokens > 0) { + while (remaining_tokens > 0 || params.interactive) { // predict if (embd.size() > 0) { const int64_t t_start_us = ggml_time_us(); @@ -955,6 +965,11 @@ int main(int argc, char ** argv) { { const int64_t t_start_sample_us = ggml_time_us(); + if (params.ignore_eos) { + // set the logit of the eos token to zero to avoid sampling it + logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0; + } + id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng); last_n_tokens.erase(last_n_tokens.begin()); @@ -978,13 +993,13 @@ int main(int argc, char ** argv) { last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(embd_inp[input_consumed]); ++input_consumed; - if (embd.size() > params.n_batch) { + if ((int) embd.size() >= params.n_batch) { break; } } // reset color to default if we there is no pending user input - if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) { + if (!input_noecho && params.use_color && (int) embd_inp.size() == input_consumed) { printf(ANSI_COLOR_RESET); } } @@ -1009,19 +1024,26 @@ int main(int argc, char ** argv) { } } if (is_interacting) { + if (params.instruct) { + input_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + + printf("\n> "); + } + // currently being interactive - bool another_line=true; + bool another_line = true; while (another_line) { fflush(stdout); char buf[256] = {0}; int n_read; - if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); + if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) { // presumable empty line, consume the newline std::ignore = scanf("%*c"); n_read=0; } - if(params.use_color) printf(ANSI_COLOR_RESET); + if (params.use_color) printf(ANSI_COLOR_RESET); if (n_read > 0 && buf[n_read-1]=='\\') { another_line = true; @@ -1036,6 +1058,10 @@ int main(int argc, char ** argv) { std::vector line_inp = ::llama_tokenize(vocab, buf, false); embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + remaining_tokens -= line_inp.size(); input_noecho = true; // do not echo this again @@ -1046,9 +1072,19 @@ int main(int argc, char ** argv) { } // end of text token - if (embd.back() == 2) { - fprintf(stderr, " [end of text]\n"); - break; + if (embd.back() == EOS_TOKEN_ID) { + if (params.interactive) { + is_interacting = true; + } else { + fprintf(stderr, " [end of text]\n"); + break; + } + } + + // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. 
+ if (params.interactive && remaining_tokens <= 0) { + remaining_tokens = params.n_predict; + is_interacting = true; } } diff --git a/prompts/alpaca.txt b/prompts/alpaca.txt new file mode 100644 index 000000000..2224bdeb0 --- /dev/null +++ b/prompts/alpaca.txt @@ -0,0 +1 @@ +Below is an instruction that describes a task. Write a response that appropriately completes the request. diff --git a/prompts/chat-with-bob.txt b/prompts/chat-with-bob.txt new file mode 100644 index 000000000..009da39ae --- /dev/null +++ b/prompts/chat-with-bob.txt @@ -0,0 +1,7 @@ +Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. + +User: Hello, Bob. +Bob: Hello. How may I help you today? +User: Please tell me the largest city in Europe. +Bob: Sure. The largest city in Europe is Moscow, the capital of Russia. +User: diff --git a/quantize.py b/quantize.py new file mode 100644 index 000000000..6320b0a26 --- /dev/null +++ b/quantize.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 + +"""Script to execute the "quantize" script on a given set of models.""" + +import subprocess +import argparse +import glob +import sys +import os + + +def main(): + """Update the quantize binary name depending on the platform and parse + the command line arguments and execute the script. + """ + + if "linux" in sys.platform or "darwin" in sys.platform: + quantize_script_binary = "quantize" + + elif "win32" in sys.platform or "cygwin" in sys.platform: + quantize_script_binary = "quantize.exe" + + else: + print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n") + quantize_script_binary = "quantize" + + parser = argparse.ArgumentParser( + prog='python3 quantize.py', + description='This script quantizes the given models by applying the ' + f'"{quantize_script_binary}" script on them.' + ) + parser.add_argument( + 'models', nargs='+', choices=('7B', '13B', '30B', '65B'), + help='The models to quantize.' + ) + parser.add_argument( + '-r', '--remove-16', action='store_true', dest='remove_f16', + help='Remove the f16 model after quantizing it.' + ) + parser.add_argument( + '-m', '--models-path', dest='models_path', + default=os.path.join(os.getcwd(), "models"), + help='Specify the directory where the models are located.' + ) + parser.add_argument( + '-q', '--quantize-script-path', dest='quantize_script_path', + default=os.path.join(os.getcwd(), quantize_script_binary), + help='Specify the path to the "quantize" script.' + ) + + # TODO: Revise this code + # parser.add_argument( + # '-t', '--threads', dest='threads', type='int', + # default=os.cpu_count(), + # help='Specify the number of threads to use to quantize many models at ' + # 'once. Defaults to os.cpu_count().' + # ) + + args = parser.parse_args() + + if not os.path.isfile(args.quantize_script_path): + print( + f'The "{quantize_script_binary}" script was not found in the ' + "current location.\nIf you want to use it from another location, " + "set the --quantize-script-path argument from the command line." + ) + sys.exit(1) + + for model in args.models: + # The model is separated in various parts + # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...) 
+ f16_model_path_base = os.path.join( + args.models_path, model, "ggml-model-f16.bin" + ) + + f16_model_parts_paths = map( + lambda filename: os.path.join(f16_model_path_base, filename), + glob.glob(f"{f16_model_path_base}*") + ) + + for f16_model_part_path in f16_model_parts_paths: + if not os.path.isfile(f16_model_part_path): + print( + f"The f16 model {os.path.basename(f16_model_part_path)} " + f"was not found in {args.models_path}{os.path.sep}{model}" + ". If you want to use it from another location, set the " + "--models-path argument from the command line." + ) + sys.exit(1) + + __run_quantize_script( + args.quantize_script_path, f16_model_part_path + ) + + if args.remove_f16: + os.remove(f16_model_part_path) + + +# This was extracted to a top-level function for parallelization, if +# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406 + +def __run_quantize_script(script_path, f16_model_part_path): + """Run the quantize script specifying the path to it and the path to the + f16 model to quantize. + """ + + new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0") + subprocess.run( + [script_path, f16_model_part_path, new_quantized_model_path, "2"], + check=True + ) + + +if __name__ == "__main__": + try: + main() + + except subprocess.CalledProcessError: + print("\nAn error ocurred while trying to quantize the models.") + sys.exit(1) + + except KeyboardInterrupt: + sys.exit(0) + + else: + print("\nSuccesfully quantized all models.") diff --git a/quantize.sh b/quantize.sh deleted file mode 100755 index 6194649b3..000000000 --- a/quantize.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env bash - -if ! [[ "$1" =~ ^[0-9]{1,2}B$ ]]; then - echo - echo "Usage: quantize.sh 7B|13B|30B|65B [--remove-f16]" - echo - exit 1 -fi - -for i in `ls models/$1/ggml-model-f16.bin*`; do - ./quantize "$i" "${i/f16/q4_0}" 2 - if [[ "$2" == "--remove-f16" ]]; then - rm "$i" - fi -done diff --git a/utils.cpp b/utils.cpp index 19665860e..08d5c6ba6 100644 --- a/utils.cpp +++ b/utils.cpp @@ -38,19 +38,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } else if (arg == "-f" || arg == "--file") { - std::ifstream file(argv[++i]); - - std::copy(std::istreambuf_iterator(file), - std::istreambuf_iterator(), - back_inserter(params.prompt)); - + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (params.prompt.back() == '\n') { + params.prompt.pop_back(); + } } else if (arg == "-n" || arg == "--n_predict") { params.n_predict = std::stoi(argv[++i]); } else if (arg == "--top_k") { params.top_k = std::stoi(argv[++i]); } else if (arg == "-c" || arg == "--ctx_size") { params.n_ctx = std::stoi(argv[++i]); + } else if (arg == "--memory_f16") { + params.memory_f16 = true; } else if (arg == "--top_p") { params.top_p = std::stof(argv[++i]); } else if (arg == "--temp") { @@ -65,16 +65,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.model = argv[++i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; - } else if (arg == "--interactive-start") { - params.interactive = true; - params.interactive_start = true; + } else if (arg == "-ins" || arg == "--instruct") { + params.instruct = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "-r" || arg == "--reverse-prompt") { params.antiprompt.push_back(argv[++i]); + 
} else if (arg == "--ignore-eos") { + params.ignore_eos = true; } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); + } else if (arg == "--random-prompt") { + params.random_prompt = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, params); @@ -85,13 +88,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return true; } -void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { +void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -i, --interactive run in interactive mode\n"); - fprintf(stderr, " --interactive-start run in interactive mode and poll user input at startup\n"); + fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -99,7 +102,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); - fprintf(stderr, " prompt to start generation with (default: random)\n"); + fprintf(stderr, " prompt to start generation with (default: empty)\n"); + fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); @@ -108,6 +112,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); + fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); + fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n"); fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " -m FNAME, --model FNAME\n"); @@ -399,7 +405,7 @@ gpt_vocab::id llama_sample_top_p_top_k( logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i)); } else { logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i)); - } + } } else { logits_id.push_back(std::make_pair(logits[i]*scale, i)); } diff --git a/utils.h b/utils.h index f14feec74..49658f7d9 100644 --- a/utils.h +++ b/utils.h @@ -18,6 +18,7 @@ struct gpt_params { int32_t n_predict = 128; // new tokens to predict int32_t repeat_last_n = 64; // last n tokens to penalize int32_t n_ctx = 512; //context size + bool memory_f16 = false; // use f16 instead of f32 for memory kv // sampling parameters int32_t top_k = 40; @@ -27,14 +28,18 @@ struct gpt_params { int32_t n_batch = 8; // batch 
size for prompt processing - std::string model = "models/lamma-7B/ggml-model.bin"; // model path - std::string prompt; + std::string model = "models/lamma-7B/ggml-model.bin"; // model path + std::string prompt = ""; + + bool random_prompt = false; bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode bool interactive_start = false; // reverse prompt immediately std::vector antiprompt; // string upon seeing which more user input is prompted + bool instruct = false; // instruction mode (used for Alpaca models) + bool ignore_eos = false; // do not stop generating after eos }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
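For reference, the rewritten `convert-pth-to-ggml.py` earlier in this patch emits a fixed file layout: `write_header` packs the magic value `0x67676d6c` followed by `vocab_size`, `dim`, `multiple_of`, `n_heads`, `n_layers`, the obsolete `rot` value (`dim // n_heads`) and `ftype` as 32-bit integers, and `write_tokens` appends one length-prefixed entry per vocabulary token. A small Python sketch that reads this header back can be handy for sanity-checking a converted model; the reader below is illustrative only and assumes the native byte order that `struct.pack` uses in the converter.

```python
# Sketch only: parse the header written by write_header()/write_tokens()
# in convert-pth-to-ggml.py. Field order mirrors the struct.pack calls there;
# byte order is assumed to be native, as in the converter.
import struct
import sys

def read_ggml_header(path):
    with open(path, "rb") as f:
        magic, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = \
            struct.unpack("8i", f.read(8 * 4))
        if magic != 0x67676d6c:  # "ggml" in hex
            raise ValueError(f"not a ggml file: magic = {magic:#x}")
        # vocabulary: an int32 length followed by that many bytes, per token
        tokens = []
        for _ in range(vocab_size):
            (length,) = struct.unpack("i", f.read(4))
            tokens.append(f.read(length))
    return {"vocab_size": vocab_size, "dim": dim, "multiple_of": multiple_of,
            "n_heads": n_heads, "n_layers": n_layers, "rot": rot,
            "ftype": ftype, "first_tokens": tokens[:5]}

if __name__ == "__main__":
    # e.g. the f16 file produced by the convert step in the README
    path = sys.argv[1] if len(sys.argv) > 1 else "models/7B/ggml-model-f16.bin"
    print(read_ggml_header(path))
```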