diff --git a/Makefile b/Makefile index 77dec0e0c..b24f8c44e 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ endif # CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++17 -fPIC LDFLAGS = # OS specific diff --git a/alpaca.sh b/alpaca.sh new file mode 100755 index 000000000..284989bc0 --- /dev/null +++ b/alpaca.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# +# Temporary script - will be removed in the future +# + +./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.96 --repeat_penalty 1 -t 7 diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index d0eb213c8..42f537769 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -16,7 +16,7 @@ # At the start of the ggml file we write the model parameters # and vocabulary. # -import os +import argparse import sys import json import struct @@ -24,136 +24,83 @@ import numpy as np import torch from sentencepiece import SentencePieceProcessor -if len(sys.argv) < 3: - print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n") - print(" ftype == 0 -> float32") - print(" ftype == 1 -> float16") - sys.exit(1) +def parse_args(): -# output in the same directory as the model -dir_model = sys.argv[1] - -fname_hparams = sys.argv[1] + "/params.json" -fname_tokenizer = sys.argv[1] + "/../tokenizer.model" + parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') + parser.add_argument('dir_model', help='directory containing the model checkpoint') + parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)') + return parser.parse_args() def get_n_parts(dim): - if dim == 4096: - return 1 - elif dim == 5120: - return 2 - elif dim == 6656: - return 4 - elif dim == 8192: - return 8 - else: - print("Invalid dim: " + str(dim)) + + mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8} + n_parts = mappings.get(dim) + if n_parts is None: + print(f"Invalid dim: {dim}") sys.exit(1) -# possible data types -# ftype == 0 -> float32 -# ftype == 1 -> float16 -# -# map from ftype to string -ftype_str = ["f32", "f16"] + print(f"n_parts = {n_parts}\n") + return n_parts -ftype = 1 -if len(sys.argv) > 2: - ftype = int(sys.argv[2]) - if ftype < 0 or ftype > 1: - print("Invalid ftype: " + str(ftype)) - sys.exit(1) - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" +def load_hparams_and_tokenizer(dir_model): -if os.path.exists(fname_out): - print(f"Skip conversion, it already exists: {fname_out}") - sys.exit(0) + fname_hparams = f"{dir_model}/params.json" + fname_tokenizer = f"{dir_model}/../tokenizer.model" -with open(fname_hparams, "r") as f: - hparams = json.load(f) + with open(fname_hparams, "r") as f: + hparams = json.load(f) + print(hparams) -tokenizer = SentencePieceProcessor(fname_tokenizer) + tokenizer = SentencePieceProcessor(fname_tokenizer) + hparams.update({"vocab_size": tokenizer.vocab_size()}) -hparams.update({"vocab_size": tokenizer.vocab_size()}) + return hparams, tokenizer -n_parts = get_n_parts(hparams["dim"]) +def write_header(fout, hparams, ftype): -print(hparams) -print('n_parts = ', n_parts) + keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"] + values = [ + 0x67676d66, # magic: ggml in hex + 1, # file version + *[hparams[key] for key in keys], + hparams["dim"] // hparams["n_heads"], # rot (obsolete) + ftype + ] + fout.write(struct.pack("i" * len(values), *values)) -for p in 
range(n_parts): - print('Processing part ', p) +def write_tokens(fout, tokenizer): - #fname_model = sys.argv[1] + "/consolidated.00.pth" - fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth" - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" - if (p > 0): - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p) - - model = torch.load(fname_model, map_location="cpu") - - fout = open(fname_out, "wb") - - fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex - fout.write(struct.pack("i", hparams["vocab_size"])) - fout.write(struct.pack("i", hparams["dim"])) - fout.write(struct.pack("i", hparams["multiple_of"])) - fout.write(struct.pack("i", hparams["n_heads"])) - fout.write(struct.pack("i", hparams["n_layers"])) - fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete) - fout.write(struct.pack("i", ftype)) - - # Is this correct?? for i in range(tokenizer.vocab_size()): if tokenizer.is_unknown(i): - # "" token (translated as ??) text = " \u2047 ".encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) elif tokenizer.is_control(i): - # ""/"" tokens - fout.write(struct.pack("i", 0)) + text = b"" elif tokenizer.is_byte(i): - # "" tokens (which may be invalid UTF-8) piece = tokenizer.id_to_piece(i) if len(piece) != 6: - print("Invalid token: " + piece) + print(f"Invalid token: {piece}") sys.exit(1) byte_value = int(piece[3:-1], 16) - fout.write(struct.pack("i", 1)) - fout.write(struct.pack("B", byte_value)) + text = struct.pack("B", byte_value) else: - # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces. text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) + fout.write(struct.pack("i", len(text))) + fout.write(text) + fout.write(struct.pack("f", tokenizer.get_score(i))) - for k, v in model.items(): - name = k - shape = v.shape +def process_and_write_variables(fout, model, ftype): - # skip layers.X.attention.inner_attention.rope.freqs - if name[-5:] == "freqs": + for name, datao in model.items(): + + if name.endswith("freqs"): continue - print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype) + shape = datao.shape - #data = tf.train.load_variable(dir_model, name).squeeze() - data = v.numpy().squeeze() - n_dims = len(data.shape); + print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}") - # for efficiency - transpose some matrices - # "model/h.*/attn/c_attn/w" - # "model/h.*/attn/c_proj/w" - # "model/h.*/mlp/c_fc/w" - # "model/h.*/mlp/c_proj/w" - #if name[-14:] == "/attn/c_attn/w" or \ - # name[-14:] == "/attn/c_proj/w" or \ - # name[-11:] == "/mlp/c_fc/w" or \ - # name[-13:] == "/mlp/c_proj/w": - # print(" Transposing") - # data = data.transpose() - - dshape = data.shape + data = datao.numpy().squeeze() + n_dims = len(shape) # default type is fp16 ftype_cur = 1 @@ -164,18 +111,40 @@ for p in range(n_parts): # header sname = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) - for i in range(n_dims): - fout.write(struct.pack("i", dshape[n_dims - 1 - i])) - fout.write(sname); + fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur)) + for dim in reversed(data.shape): + fout.write(struct.pack("i", dim)) + fout.write(sname) - # data + # data output to file data.tofile(fout) - # I hope this deallocates the memory .. 
- model = None +def main(): - fout.close() + args = parse_args() + dir_model = args.dir_model + ftype = args.ftype + ftype_str = ["f32", "f16"] - print("Done. Output file: " + fname_out + ", (part ", p, ")") - print("") + hparams, tokenizer = load_hparams_and_tokenizer(dir_model) + n_parts = get_n_parts(hparams["dim"]) + + for p in range(n_parts): + + print(f"Processing part {p}\n") + + fname_model = f"{dir_model}/consolidated.0{p}.pth" + fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}" + + model = torch.load(fname_model, map_location="cpu") + + with open(fname_out, "wb") as fout: + write_header(fout, hparams, ftype) + write_tokens(fout, tokenizer) + process_and_write_variables(fout, model, ftype) + + del model + print(f"Done. Output file: {fname_out}, (part {p})\n") + +if __name__ == "__main__": + main() diff --git a/expose.cpp b/expose.cpp index a72506a8b..3b2eb07ef 100644 --- a/expose.cpp +++ b/expose.cpp @@ -55,7 +55,7 @@ extern "C" { int n_parts_overwrite = inputs.n_parts_overwrite; - if (!llama_model_load(api_params.model, api_model, api_vocab, api_params.n_ctx, n_parts_overwrite)) { + if (!llama_model_load(api_params.model, api_model, api_vocab, api_params.n_ctx, GGML_TYPE_F16, n_parts_overwrite)) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, api_params.model.c_str()); return false; } @@ -101,7 +101,6 @@ extern "C" { if(reset_state) { api_params.prompt.insert(0, 1, ' '); - mem_per_token = 0; } // tokenize the prompt std::vector embd_inp = ::llama_tokenize(api_vocab, api_params.prompt, true); @@ -164,7 +163,7 @@ extern "C" { { // set the logit of the eos token (2) to zero to avoid sampling it - api_logits[api_logits.size() - n_vocab + 2] = 0; + api_logits[api_logits.size() - n_vocab + EOS_TOKEN_ID] = 0; //set logits of opening square bracket to zero. api_logits[api_logits.size() - n_vocab + 518] = 0; api_logits[api_logits.size() - n_vocab + 29961] = 0; diff --git a/llamacpp.dll b/llamacpp.dll index 04c8b260b..6ded807cf 100644 Binary files a/llamacpp.dll and b/llamacpp.dll differ diff --git a/main.cpp b/main.cpp index 4c9b2ff9b..005aa22cc 100644 --- a/main.cpp +++ b/main.cpp @@ -3,10 +3,12 @@ #include "utils.h" #include +#include #include #include #include #include +#include #include #include #include @@ -27,6 +29,8 @@ #define ANSI_COLOR_RESET "\x1b[0m" #define ANSI_BOLD "\x1b[1m" +static const int EOS_TOKEN_ID = 2; + // determine number of model parts based on the dimension static const std::map LLAMA_N_PARTS = { { 4096, 1 }, @@ -86,30 +90,40 @@ struct llama_model { }; // load the model's weights from a file -bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, int n_parts_overwrite=-1) { +bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32, int n_parts_overwrite=-1) { fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + std::vector f_buf(1024*1024); - FILE *fin = fopen(fname.data(), "rb"); - + auto fin = std::ifstream(fname, std::ios::binary); + fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size()); if (!fin) { fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); return false; } - // Having a large buffer helps to accelerate load considerably (old buffer was 1024 * 1024). - // Though I am not sure if it's okay for edge devices like Raspberry Pi. 
- std::vector f_buf(128 * 1024 * 1024); - setvbuf(fin, f_buf.data(), _IOFBF, f_buf.size()); - // verify magic { uint32_t magic; - fread((char *) &magic, 1, sizeof(magic), fin); - if (magic != 0x67676d6c) { + fin.read((char *) &magic, sizeof(magic)); + if (magic == 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n", + __func__, fname.c_str()); + return false; + } + if (magic != 0x67676d66) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; } + + uint32_t format_version; + fin.read((char *) &format_version, sizeof(format_version)); + + if (format_version != 1) { + fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ")\n", + __func__, fname.c_str(), format_version); + return false; + } } int n_ff = 0; @@ -119,14 +133,14 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab { auto & hparams = model.hparams; - fread((char *) &hparams.n_vocab, 1, sizeof(hparams.n_vocab), fin); + fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); //fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); - fread((char *) &hparams.n_embd, 1, sizeof(hparams.n_embd), fin); - fread((char *) &hparams.n_mult, 1, sizeof(hparams.n_mult), fin); - fread((char *) &hparams.n_head, 1, sizeof(hparams.n_head), fin); - fread((char *) &hparams.n_layer, 1, sizeof(hparams.n_layer), fin); - fread((char *) &hparams.n_rot, 1, sizeof(hparams.n_rot), fin); - fread((char *) &hparams.f16, 1, sizeof(hparams.f16), fin); + fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + fin.read((char *) &hparams.n_mult, sizeof(hparams.n_mult)); + fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); + fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot)); + fin.read((char *) &hparams.f16, sizeof(hparams.f16)); hparams.n_ctx = n_ctx; @@ -154,13 +168,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab std::string word; for (int i = 0; i < model.hparams.n_vocab; i++) { uint32_t len; - fread((char *) &len, 1, sizeof(len), fin); + fin.read((char *) &len, sizeof(len)); word.resize(len); - fread((char *) word.data(), 1, len, fin); + fin.read((char *) word.data(), len); + + float score; + fin.read((char *) &score, sizeof(score)); vocab.token_to_id[word] = i; vocab.id_to_token[i] = word; + vocab.score[i] = score; //if (i < 30000) { // fprintf(stderr, "%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); @@ -184,8 +202,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab } } - const ggml_type wtype2 = GGML_TYPE_F32; - auto & ctx = model.ctx; size_t ctx_size = 0; @@ -217,8 +233,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2 ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3 - ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k - ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v ctx_size += (5 + 10*n_layer)*256; // object overhead @@ -245,7 +261,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; - const int n_ctx = 
hparams.n_ctx; const int n_vocab = hparams.n_vocab; model.layers.resize(n_layer); @@ -304,17 +319,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab const int n_mem = n_layer*n_ctx; const int n_elements = n_embd*n_mem; - model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); - model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements); const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); fprintf(stderr, "%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem); } - const size_t file_offset = ftell(fin); + const size_t file_offset = fin.tellg(); - fclose(fin); + fin.close(); std::vector tmp; @@ -329,9 +344,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str()); - fin = fopen(fname_part.data(), "rb"); - setvbuf(fin, f_buf.data(), _IOFBF, f_buf.size()); - fseek(fin, file_offset, SEEK_CUR); + fin = std::ifstream(fname_part, std::ios::binary); + fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size()); + fin.seekg(file_offset); + // load weights { int n_tensors = 0; @@ -344,24 +360,23 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab int32_t length; int32_t ftype; - fread(reinterpret_cast(&n_dims), 1, sizeof(n_dims), fin); - fread(reinterpret_cast(&length), 1, sizeof(length), fin); - fread(reinterpret_cast(&ftype), 1, sizeof(ftype), fin); + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); - - if (feof(fin)) { + if (fin.eof()) { break; } int32_t nelements = 1; int32_t ne[2] = { 1, 1 }; for (int i = 0; i < n_dims; ++i) { - fread(reinterpret_cast(&ne[i]), 1, sizeof(ne[i]), fin); + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); nelements *= ne[i]; } std::string name(length, 0); - fread(&name[0], 1, length, fin); + fin.read(&name[0], length); if (model.tensors.find(name.data()) == model.tensors.end()) { fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); @@ -463,9 +478,9 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab } if (part_id == 0) { - fread(reinterpret_cast(tensor->data), 1, ggml_nbytes(tensor), fin); + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); } else { - fseek(fin, ggml_nbytes(tensor), SEEK_CUR); + fin.seekg(ggml_nbytes(tensor), std::ios::cur); } total_size += ggml_nbytes(tensor); @@ -485,7 +500,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab for (int i1 = 0; i1 < ne[1]; ++i1) { const size_t offset_row = i1*row_size; const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type); - fread(reinterpret_cast(tensor->data) + offset, 1, row_size/n_parts, fin); + fin.read(reinterpret_cast(tensor->data) + offset, row_size/n_parts); } } else { const int np1 = ne[1]; @@ -494,7 +509,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab for (int i1 = 0; i1 < ne[1]; ++i1) { const size_t offset_row = (i1 + part_id*np1)*row_size; - fread(reinterpret_cast(tensor->data) + offset_row, 1, row_size, fin); + fin.read(reinterpret_cast(tensor->data) + offset_row, row_size); } } @@ -513,11 
+528,12 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); } - fclose(fin); + fin.close(); } return true; } + // evaluate the transformer // // - model: the model @@ -546,9 +562,9 @@ bool llama_eval( const int n_vocab = hparams.n_vocab; const int n_rot = hparams.n_embd/hparams.n_head; - const int d_key = n_embd/n_head; - - static size_t buf_size = (size_t)hparams.n_ctx*1024*1024; + // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case + // static size_t buf_size = hparams.n_ctx*1024*1024; + static size_t buf_size = 512u*1024*1024; static void * buf = malloc(buf_size); if (mem_per_token > 0 && mem_per_token*N > buf_size) { @@ -757,6 +773,7 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { printf(ANSI_COLOR_RESET); + printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; @@ -797,7 +814,7 @@ int main(int argc, char ** argv) { if (gpt_params_parse(argc, argv, params) == false) { return 1; } - + if (params.n_ctx > 2048) { fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);" "expect poor results\n", __func__, params.n_ctx); @@ -810,7 +827,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); std::mt19937 rng(params.seed); - if (params.prompt.empty()) { + if (params.random_prompt) { params.prompt = gpt_random_prompt(rng); } @@ -824,8 +841,9 @@ int main(int argc, char ** argv) { // load the model { + const ggml_type memory_type = params.memory_f16 ? 
GGML_TYPE_F16 : GGML_TYPE_F32; const int64_t t_start_us = ggml_time_us(); - if (!llama_model_load(params.model, model, vocab, params.n_ctx)) { + if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return 1; } @@ -854,8 +872,27 @@ int main(int argc, char ** argv) { params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); + // prefix & suffix for instruct mode + const std::vector inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true); + const std::vector inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false); + + // in instruct mode, we inject a prefix and a suffix to each input by the user + if (params.instruct) { + params.interactive = true; + params.antiprompt.push_back("### Instruction:\n\n"); + } + // tokenize the reverse prompt - std::vector antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false); + std::vector> antipromptv_inp; + + for (auto antiprompt : params.antiprompt) { + antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false)); + } + + // enable interactive mode if reverse prompt is specified + if (antipromptv_inp.size() != 0) { + params.interactive = true; + } fprintf(stderr, "\n"); fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); @@ -877,13 +914,16 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: interactive mode on.\n", __func__); - if(antiprompt_inp.size()) { - fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str()); - fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size()); - for (int i = 0; i < (int) antiprompt_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str()); + if(antipromptv_inp.size()) { + for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) { + auto antiprompt_inp = antipromptv_inp.at(apindex); + fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str()); + fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size()); + for (int i = 0; i < (int) antiprompt_inp.size(); i++) { + fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str()); + } + fprintf(stderr, "\n"); } - fprintf(stderr, "\n"); } } fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); @@ -899,31 +939,27 @@ int main(int argc, char ** argv) { std::vector last_n_tokens(last_n_size); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); - if (params.interactive) { fprintf(stderr, "== Running in interactive mode. 
==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n"); + " - If you want to submit another line, end your input in '\\'.\n\n"); + is_interacting = true; } - int remaining_tokens = params.n_predict; int input_consumed = 0; bool input_noecho = false; - // prompt user immediately after the starting prompt has been loaded - if (params.interactive_start) { - is_interacting = true; - } + int remaining_tokens = params.n_predict; // set the color for the prompt which will be output initially if (params.use_color) { printf(ANSI_COLOR_YELLOW); } - while (remaining_tokens > 0) { + while (remaining_tokens > 0 || params.interactive) { // predict if (embd.size() > 0) { const int64_t t_start_us = ggml_time_us(); @@ -953,8 +989,10 @@ int main(int argc, char ** argv) { { const int64_t t_start_sample_us = ggml_time_us(); - // set the logit of the eos token (2) to zero to avoid sampling it - logits[logits.size() - n_vocab + 2] = 0; + if (params.ignore_eos) { + // set the logit of the eos token to zero to avoid sampling it + logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0; + } id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng); @@ -979,15 +1017,10 @@ int main(int argc, char ** argv) { last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(embd_inp[input_consumed]); ++input_consumed; - if (embd.size() > params.n_batch) { + if ((int) embd.size() >= params.n_batch) { break; } } - - // reset color to default if we there is no pending user input - if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) { - printf(ANSI_COLOR_RESET); - } } // display text @@ -997,56 +1030,74 @@ int main(int argc, char ** argv) { } fflush(stdout); } + // reset color to default if we there is no pending user input + if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) { + printf(ANSI_COLOR_RESET); + } // in interactive mode, and not currently processing queued inputs; // check if we should prompt the user for more if (params.interactive && embd_inp.size() <= input_consumed) { // check for reverse prompt - if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { - // reverse prompt found - is_interacting = true; + for (auto antiprompt_inp : antipromptv_inp) { + if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { + // reverse prompt found + is_interacting = true; + break; + } } if (is_interacting) { - // currently being interactive - bool another_line=true; - while (another_line) { - fflush(stdout); - char buf[256] = {0}; - int n_read; - if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); - if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) { - // presumable empty line, consume the newline - std::ignore = scanf("%*c"); - n_read=0; - } - if(params.use_color) printf(ANSI_COLOR_RESET); + if (params.instruct) { + input_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); - if (n_read > 0 && buf[n_read-1]=='\\') { - another_line = true; - buf[n_read-1] = '\n'; - buf[n_read] = 0; - } else { - another_line = false; - buf[n_read] = '\n'; - buf[n_read+1] = 0; - } - - std::vector line_inp = ::llama_tokenize(vocab, buf, false); - 
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); - - remaining_tokens -= line_inp.size(); - - input_noecho = true; // do not echo this again + printf("\n> "); } - is_interacting = false; + // currently being interactive + if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); + std::string buffer; + std::string line; + bool another_line = true; + do { + std::getline(std::cin, line); + if (line.empty() || line.back() != '\\') { + another_line = false; + } else { + line.pop_back(); // Remove the continue character + } + buffer += line + '\n'; // Append the line to the result + } while (another_line); + if (params.use_color) printf(ANSI_COLOR_RESET); + + std::vector line_inp = ::llama_tokenize(vocab, buffer, false); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + remaining_tokens -= line_inp.size(); + + input_noecho = true; // do not echo this again } + is_interacting = false; } // end of text token - if (embd.back() == 2) { - fprintf(stderr, " [end of text]\n"); - break; + if (embd.back() == EOS_TOKEN_ID) { + if (params.interactive) { + is_interacting = true; + } else { + fprintf(stderr, " [end of text]\n"); + break; + } + } + + // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. + if (params.interactive && remaining_tokens <= 0) { + remaining_tokens = params.n_predict; + is_interacting = true; } } diff --git a/main.exe b/main.exe index 22210d3a0..ab68d1df8 100644 Binary files a/main.exe and b/main.exe differ diff --git a/prompts/alpaca.txt b/prompts/alpaca.txt new file mode 100644 index 000000000..2224bdeb0 --- /dev/null +++ b/prompts/alpaca.txt @@ -0,0 +1 @@ +Below is an instruction that describes a task. Write a response that appropriately completes the request. diff --git a/prompts/chat-with-bob.txt b/prompts/chat-with-bob.txt new file mode 100644 index 000000000..009da39ae --- /dev/null +++ b/prompts/chat-with-bob.txt @@ -0,0 +1,7 @@ +Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. + +User: Hello, Bob. +Bob: Hello. How may I help you today? +User: Please tell me the largest city in Europe. +Bob: Sure. The largest city in Europe is Moscow, the capital of Russia. 
+User: diff --git a/quantize.cpp b/quantize.cpp index 14c7b277a..166e9163a 100644 --- a/quantize.cpp +++ b/quantize.cpp @@ -3,6 +3,7 @@ #include "utils.h" #include +#include #include #include #include @@ -63,12 +64,28 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { + if (magic == 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n", + __func__, fname_inp.c_str()); + return false; + } + if (magic != 0x67676d66) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); return false; } fout.write((char *) &magic, sizeof(magic)); + + uint32_t format_version; + finp.read((char *) &format_version, sizeof(format_version)); + + if (format_version != 1) { + fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ")\n", + __func__, fname_inp.c_str(), format_version); + return false; + } + + fout.write((char *) &format_version, sizeof(format_version)); } llama_hparams hparams; @@ -122,8 +139,13 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna finp.read ((char *) word.data(), len); fout.write((char *) word.data(), len); + float score; + finp.read ((char *) &score, sizeof(score)); + fout.write((char *) &score, sizeof(score)); + vocab.token_to_id[word] = i; vocab.id_to_token[i] = word; + vocab.score[i] = score; } } diff --git a/quantize.exe b/quantize.exe index 19e6a51d4..41c0cc9c8 100644 Binary files a/quantize.exe and b/quantize.exe differ diff --git a/quantize.py b/quantize.py new file mode 100644 index 000000000..6320b0a26 --- /dev/null +++ b/quantize.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 + +"""Script to execute the "quantize" script on a given set of models.""" + +import subprocess +import argparse +import glob +import sys +import os + + +def main(): + """Update the quantize binary name depending on the platform and parse + the command line arguments and execute the script. + """ + + if "linux" in sys.platform or "darwin" in sys.platform: + quantize_script_binary = "quantize" + + elif "win32" in sys.platform or "cygwin" in sys.platform: + quantize_script_binary = "quantize.exe" + + else: + print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n") + quantize_script_binary = "quantize" + + parser = argparse.ArgumentParser( + prog='python3 quantize.py', + description='This script quantizes the given models by applying the ' + f'"{quantize_script_binary}" script on them.' + ) + parser.add_argument( + 'models', nargs='+', choices=('7B', '13B', '30B', '65B'), + help='The models to quantize.' + ) + parser.add_argument( + '-r', '--remove-16', action='store_true', dest='remove_f16', + help='Remove the f16 model after quantizing it.' + ) + parser.add_argument( + '-m', '--models-path', dest='models_path', + default=os.path.join(os.getcwd(), "models"), + help='Specify the directory where the models are located.' + ) + parser.add_argument( + '-q', '--quantize-script-path', dest='quantize_script_path', + default=os.path.join(os.getcwd(), quantize_script_binary), + help='Specify the path to the "quantize" script.' + ) + + # TODO: Revise this code + # parser.add_argument( + # '-t', '--threads', dest='threads', type='int', + # default=os.cpu_count(), + # help='Specify the number of threads to use to quantize many models at ' + # 'once. Defaults to os.cpu_count().' 
+    # )
+
+    args = parser.parse_args()
+
+    if not os.path.isfile(args.quantize_script_path):
+        print(
+            f'The "{quantize_script_binary}" script was not found in the '
+            "current location.\nIf you want to use it from another location, "
+            "set the --quantize-script-path argument from the command line."
+        )
+        sys.exit(1)
+
+    for model in args.models:
+        # The model is separated in various parts
+        # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
+        f16_model_path_base = os.path.join(
+            args.models_path, model, "ggml-model-f16.bin"
+        )
+
+        f16_model_parts_paths = map(
+            lambda filename: os.path.join(f16_model_path_base, filename),
+            glob.glob(f"{f16_model_path_base}*")
+        )
+
+        for f16_model_part_path in f16_model_parts_paths:
+            if not os.path.isfile(f16_model_part_path):
+                print(
+                    f"The f16 model {os.path.basename(f16_model_part_path)} "
+                    f"was not found in {args.models_path}{os.path.sep}{model}"
+                    ". If you want to use it from another location, set the "
+                    "--models-path argument from the command line."
+                )
+                sys.exit(1)
+
+            __run_quantize_script(
+                args.quantize_script_path, f16_model_part_path
+            )
+
+            if args.remove_f16:
+                os.remove(f16_model_part_path)
+
+
+# This was extracted to a top-level function for parallelization, if
+# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
+
+def __run_quantize_script(script_path, f16_model_part_path):
+    """Run the quantize script specifying the path to it and the path to the
+    f16 model to quantize.
+    """
+
+    new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
+    subprocess.run(
+        [script_path, f16_model_part_path, new_quantized_model_path, "2"],
+        check=True
+    )
+
+
+if __name__ == "__main__":
+    try:
+        main()
+
+    except subprocess.CalledProcessError:
+        print("\nAn error occurred while trying to quantize the models.")
+        sys.exit(1)
+
+    except KeyboardInterrupt:
+        sys.exit(0)
+
+    else:
+        print("\nSuccessfully quantized all models.")
diff --git a/quantize.sh b/quantize.sh
deleted file mode 100755
index 6194649b3..000000000
--- a/quantize.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env bash
-
-if ! 
[[ "$1" =~ ^[0-9]{1,2}B$ ]]; then - echo - echo "Usage: quantize.sh 7B|13B|30B|65B [--remove-f16]" - echo - exit 1 -fi - -for i in `ls models/$1/ggml-model-f16.bin*`; do - ./quantize "$i" "${i/f16/q4_0}" 2 - if [[ "$2" == "--remove-f16" ]]; then - rm "$i" - fi -done diff --git a/utils.cpp b/utils.cpp index efa2e3c35..188f114e9 100644 --- a/utils.cpp +++ b/utils.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -38,19 +39,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } else if (arg == "-f" || arg == "--file") { - std::ifstream file(argv[++i]); - - std::copy(std::istreambuf_iterator(file), - std::istreambuf_iterator(), - back_inserter(params.prompt)); - + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (params.prompt.back() == '\n') { + params.prompt.pop_back(); + } } else if (arg == "-n" || arg == "--n_predict") { params.n_predict = std::stoi(argv[++i]); } else if (arg == "--top_k") { params.top_k = std::stoi(argv[++i]); } else if (arg == "-c" || arg == "--ctx_size") { params.n_ctx = std::stoi(argv[++i]); + } else if (arg == "--memory_f16") { + params.memory_f16 = true; } else if (arg == "--top_p") { params.top_p = std::stof(argv[++i]); } else if (arg == "--temp") { @@ -65,16 +66,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.model = argv[++i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; - } else if (arg == "--interactive-start") { - params.interactive = true; - params.interactive_start = true; + } else if (arg == "-ins" || arg == "--instruct") { + params.instruct = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "-r" || arg == "--reverse-prompt") { - params.antiprompt = argv[++i]; + params.antiprompt.push_back(argv[++i]); + } else if (arg == "--ignore-eos") { + params.ignore_eos = true; } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); + } else if (arg == "--random-prompt") { + params.random_prompt = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, params); @@ -85,20 +89,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return true; } -void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { +void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -i, --interactive run in interactive mode\n"); - fprintf(stderr, " --interactive-start run in interactive mode and poll user input at startup\n"); + fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT\n"); + fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT (can be\n"); + fprintf(stderr, " specified more than once for multiple prompts).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); 
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); - fprintf(stderr, " prompt to start generation with (default: random)\n"); + fprintf(stderr, " prompt to start generation with (default: empty)\n"); + fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); @@ -107,6 +113,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); + fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); + fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n"); fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " -m FNAME, --model FNAME\n"); @@ -287,58 +295,146 @@ std::vector gpt_tokenize(const gpt_vocab & vocab, const std::stri return tokens; } -// TODO: Calculate this constant from the vocabulary -#define MAX_TOKEN_LEN 18 -// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece -std::vector llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) { - std::vector res; - std::vector score; - std::vector prev; - int len = text.length(); +static size_t utf8_len(char src) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; +} - score.resize(len + 1); - prev.resize(len + 1); +struct llama_sp_symbol { + using index = int; + index prev; + index next; + std::string_view text; +}; - // Forward pass - for (int i = 0; i < len; i++) { - int max_len = std::min(len - i, MAX_TOKEN_LEN); - for (int sub_len = 1; sub_len <= max_len; sub_len++) { - auto sub = text.substr(i, sub_len); - auto token = vocab.token_to_id.find(sub); - if (token != vocab.token_to_id.end()) { - int token_score = sub.length() * sub.length(); - int local_score = score[i] + token_score; - int next = i + sub_len; - if (score[next] < local_score) { - score[next] = local_score; - prev[next] = (*token).second; +struct llama_sp_bigram { + struct comparator { + bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); + } + }; + using queue_storage = std::vector; + using queue = std::priority_queue; + llama_sp_symbol::index left; + llama_sp_symbol::index right; + float score; + size_t size; +}; + +struct llama_tokenizer { + llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {} + + void tokenize(std::string_view text, std::vector & output) { + // split string into utf8 chars + int index = 0; + while (!text.empty()) { + llama_sp_symbol sym; + size_t char_len = std::min(text.size(), utf8_len(text.data()[0])); + sym.text = std::string_view(text.data(), char_len); + sym.prev = index - 1; + text.remove_prefix(char_len); + sym.next = text.empty() ? 
-1 : index + 1; + index++; + symbols_.emplace_back(std::move(sym)); + } + + // seed the work queue with all possible 2-character tokens. + for (size_t i = 1; i < symbols_.size(); ++i) { + try_add_bigram(i - 1, i); + } + + // keep substituting the highest frequency pairs for as long as we can. + while (!work_queue_.empty()) { + auto bigram = work_queue_.top(); + work_queue_.pop(); + + auto & left_sym = symbols_[bigram.left]; + auto & right_sym = symbols_[bigram.right]; + + // if one of the symbols already got merged, skip it. + if (left_sym.text.empty() || right_sym.text.empty() || + left_sym.text.size() + right_sym.text.size() != bigram.size) { + continue; + } + + // merge the right sym into the left one + left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size()); + right_sym.text = std::string_view(""); + + // remove the right sym from the chain + left_sym.next = right_sym.next; + if (right_sym.next >= 0) { + symbols_[right_sym.next].prev = bigram.left; + } + + // find more substitutions + try_add_bigram(left_sym.prev, bigram.left); + try_add_bigram(bigram.left, left_sym.next); + } + + for (int i = 0; i != -1; i = symbols_[i].next) { + auto& symbol = symbols_[i]; + auto token = vocab_.token_to_id.find(std::string(symbol.text)); + + if (token == vocab_.token_to_id.end()) { + // output any symbols that did not form tokens as bytes. + for (int j = 0; j < symbol.text.size(); ++j) { + gpt_vocab::id token_id = static_cast(symbol.text[j]) + 3; + output.push_back(token_id); } + } else { + output.push_back((*token).second); } } } - // Backward pass - int i = len; - while (i > 0) { - gpt_vocab::id token_id = prev[i]; - if (token_id == 0) { - // TODO: Return error or something more meaningful - printf("failed to tokenize string!\n"); - break; +private: + void try_add_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; } - res.push_back(token_id); - auto token = (*vocab.id_to_token.find(token_id)).second; - i -= token.length(); + + std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size()); + auto token = vocab_.token_to_id.find(std::string(text)); + + if (token == vocab_.token_to_id.end()) { + return; + } + + auto score = vocab_.score.find((*token).second); + + if (score == vocab_.score.end()) { + return; + } + + llama_sp_bigram bigram; + bigram.left = left; + bigram.right = right; + bigram.score = (*score).second; + bigram.size = text.size(); + work_queue_.push(bigram); + } + + const gpt_vocab & vocab_; + std::vector symbols_; + llama_sp_bigram::queue work_queue_; +}; + +std::vector llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) { + llama_tokenizer tokenizer(vocab); + std::vector output; + + if (text.size() == 0) { + return output; } if (bos) { - res.push_back(1); // TODO: replace with vocab.bos + output.push_back(1); } - // Pieces are in reverse order so correct that - std::reverse(res.begin(), res.end()); - - return res; + tokenizer.tokenize(text, output); + return output; } bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) { @@ -398,7 +494,7 @@ gpt_vocab::id llama_sample_top_p_top_k( logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i)); } else { logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i)); - } + } } else { logits_id.push_back(std::make_pair(logits[i]*scale, i)); } diff --git a/utils.h b/utils.h index c1a8498a7..b3a0f4724 100644 --- a/utils.h +++ b/utils.h @@ -18,6 +18,7 @@ struct gpt_params { 
     int32_t n_predict = 128; // new tokens to predict
     int32_t repeat_last_n = 64; // last n tokens to penalize
     int32_t n_ctx = 512; //context size
+    bool memory_f16 = false; // use f16 instead of f32 for memory kv
 
     // sampling parameters
     int32_t top_k = 40;
@@ -27,14 +28,18 @@ struct gpt_params {
 
     int32_t n_batch = 8; // batch size for prompt processing
 
-    std::string model = "models/lamma-7B/ggml-model.bin"; // model path
-    std::string prompt;
+    std::string model      = "models/lamma-7B/ggml-model.bin"; // model path
+    std::string prompt     = "";
+
+    bool random_prompt = false;
 
     bool use_color = false; // use color to distinguish generations and inputs
 
     bool interactive = false; // interactive mode
     bool interactive_start = false; // reverse prompt immediately
-    std::string antiprompt = ""; // string upon seeing which more user input is prompted
+    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+    bool instruct = false; // instruction mode (used for Alpaca models)
+    bool ignore_eos = false; // do not stop generating after eos
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@@ -53,6 +58,7 @@ struct gpt_vocab {
 
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
+    std::map<id, float> score;
 };
 
 void replace(std::string & str, const std::string & needle, const std::string & replacement);
@@ -74,7 +80,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
 
 // TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
 // ref: https://github.com/google/sentencepiece
-std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos);
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos);
 
 // load the tokens from encoder.json
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
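
Notes on the changes above, with small illustrative Python sketches. The helper names and paths in these sketches are invented for illustration and are not part of the patch.

The converted files move from the old un-versioned ggml magic (0x67676d6c) to a new magic 0x67676d66 followed by a 32-bit format version (currently 1), and every vocabulary entry now carries a float score after its text, which is what the rewritten tokenizer keys its merges on. Both main.cpp and quantize.cpp reject old files and unknown versions. A minimal sketch of reading that header and scored vocabulary back, assuming exactly the layout emitted by write_header() and write_tokens() in convert-pth-to-ggml.py:

import struct

def read_ggmf_header_and_vocab(path):
    # hypothetical helper: parse only the header and vocab, not the tensor data
    with open(path, "rb") as f:
        magic, version = struct.unpack("ii", f.read(8))
        assert magic == 0x67676d66, "not a versioned ggml file (too old?)"
        assert version == 1, f"unsupported format version {version}"

        # same order as write_header(): vocab_size, dim, multiple_of,
        # n_heads, n_layers, rot (dim // n_heads, obsolete), ftype
        n_vocab, dim, multiple_of, n_heads, n_layers, rot, ftype = \
            struct.unpack("7i", f.read(7 * 4))

        vocab = []
        for _ in range(n_vocab):
            (length,) = struct.unpack("i", f.read(4))
            text = f.read(length)
            (score,) = struct.unpack("f", f.read(4))
            vocab.append((text, score))

    return {"n_vocab": n_vocab, "dim": dim, "ftype": ftype}, vocab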
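
Checkpoints above 7B are split into several parts: get_n_parts() maps dim 4096/5120/6656/8192 to 1/2/4/8, and the converter, the loader, and quantize.py all rely on the same naming scheme, where part 0 is the bare ggml-model-<ftype>.bin and later parts append .1, .2, and so on. A small sketch of that convention, using a hypothetical models directory:

import os

def part_filenames(dir_model, ftype_str="f16", n_parts=2):
    # part 0 has no numeric suffix, later parts get ".1", ".2", ...
    base = os.path.join(dir_model, f"ggml-model-{ftype_str}.bin")
    return [base if p == 0 else f"{base}.{p}" for p in range(n_parts)]

print(part_filenames("models/13B"))
# ['models/13B/ggml-model-f16.bin', 'models/13B/ggml-model-f16.bin.1']

quantize.py picks these parts up with glob.glob(f"{f16_model_path_base}*") and runs the quantize binary on each one.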
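
The new --memory_f16 flag stores the memory_k/memory_v cache as F16 instead of F32, which halves its size: each of the two tensors holds n_ctx * n_layer * n_embd elements. A back-of-the-envelope check, assuming 7B-sized hyperparameters (n_embd = 4096, n_layer = 32) at the default context size:

def kv_cache_bytes(n_ctx, n_layer, n_embd, bytes_per_element):
    # memory_k and memory_v are each n_ctx * n_layer * n_embd elements
    return 2 * n_ctx * n_layer * n_embd * bytes_per_element

n_ctx, n_layer, n_embd = 512, 32, 4096
print(kv_cache_bytes(n_ctx, n_layer, n_embd, 4) / 1024 / 1024)  # f32: 512.0 MB
print(kv_cache_bytes(n_ctx, n_layer, n_embd, 2) / 1024 / 1024)  # f16: 256.0 MB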
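
utf8_len() in utils.cpp determines how many bytes a UTF-8 sequence occupies from its first byte alone: the top four bits index a 16-entry table (0xxx leads a 1-byte sequence, 110x a 2-byte one, 1110 a 3-byte one, 11110 a 4-byte one, and continuation bytes map to 1 so malformed input still advances). The same table expressed in Python:

def utf8_len(first_byte):
    lookup = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]
    return lookup[first_byte >> 4]

assert utf8_len("a".encode("utf-8")[0]) == 1
assert utf8_len("€".encode("utf-8")[0]) == 3  # 0xE2, high nibble 0xE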
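
llama_tokenize() is rewritten from the old longest-match dynamic programming pass into a SentencePiece-style merge loop: the input is split into UTF-8 characters, every adjacent pair that exists in the vocabulary is considered, the highest-scoring pair is merged, and this repeats until no merge applies; anything that never becomes a known token falls back to byte tokens (byte value + 3), and the caller still prepends token 1 when bos is true. The sketch below mirrors that logic in simplified form, operating on Python strings and plain dicts and using an O(n^2) scan instead of the linked list plus priority queue of llama_tokenizer, so it is an approximation rather than a drop-in equivalent:

def sp_tokenize(text, token_to_id, score):
    # hypothetical helper: token_to_id maps piece text -> id, score maps id -> float
    symbols = list(text)  # the C++ code splits into UTF-8 characters

    while True:
        best = None  # (score, index) of the best adjacent pair that is a token
        for i in range(len(symbols) - 1):
            tok = token_to_id.get(symbols[i] + symbols[i + 1])
            if tok is not None and (best is None or score[tok] > best[0]):
                best = (score[tok], i)
        if best is None:
            break
        i = best[1]
        symbols[i:i + 2] = [symbols[i] + symbols[i + 1]]  # merge the pair

    out = []
    for sym in symbols:
        if sym in token_to_id:
            out.append(token_to_id[sym])
        else:
            # unknown pieces are emitted as byte tokens, offset by 3
            out.extend(b + 3 for b in sym.encode("utf-8"))
    return out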
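
In instruct mode (-ins), main.cpp forces interactive mode, registers "### Instruction:\n\n" as a reverse prompt, and wraps every user turn in the tokenized prefix and suffix shown above, so Alpaca-style models see the format they were trained on. Roughly, at the string level and ignoring the '\' line-continuation handling:

INSTRUCT_PREFIX = "\n\n### Instruction:\n\n"
INSTRUCT_SUFFIX = "\n\n### Response:\n\n"

def wrap_instruct_turn(user_input):
    # mirrors main.cpp: inp_pfx tokens, the user's line(s), then inp_sfx tokens
    # are appended to embd_inp before generation resumes
    return INSTRUCT_PREFIX + user_input + INSTRUCT_SUFFIX

def hit_reverse_prompt(last_n_tokens, antipromptv_inp):
    # generation hands control back to the user when the most recent tokens
    # end with any of the tokenized reverse prompts
    return any(anti and last_n_tokens[-len(anti):] == anti
               for anti in antipromptv_inp)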