Merge branch 'ggerganov:master' into master

Commit 564cdf8f4b by OvJat, 2023-03-20 11:36:24 +08:00, committed by GitHub.
GPG key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
11 changed files with 397 additions and 227 deletions

README.md

@@ -7,7 +7,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 **Hot topics:**
 
-- RMSNorm implementation / fixes: https://github.com/ggerganov/llama.cpp/issues/173
+- [Added Alpaca support](https://github.com/ggerganov/llama.cpp#instruction-mode-with-alpaca)
 - Cache input prompts for faster initialization: https://github.com/ggerganov/llama.cpp/issues/64
 - Create a `llama.cpp` logo: https://github.com/ggerganov/llama.cpp/issues/105
@@ -147,7 +147,7 @@ python3 -m pip install torch numpy sentencepiece
 python3 convert-pth-to-ggml.py models/7B/ 1
 
 # quantize the model to 4-bits
-./quantize.sh 7B
+python3 quantize.py 7B
 
 # run the inference
 ./main -m ./models/7B/ggml-model-q4_0.bin -n 128
@@ -176,21 +176,51 @@ In this mode, you can always interrupt generation by pressing Ctrl+C and enter o
 Here is an example few-shot interaction, invoked with the command
 
 ```
-./main -m ./models/13B/ggml-model-q4_0.bin -n 256 --repeat_penalty 1.0 --color -i -r "User:" \
--p \
-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
-User: Hello, Bob.
-Bob: Hello. How may I help you today?
-User: Please tell me the largest city in Europe.
-Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
-User:"
+./main -m ./models/13B/ggml-model-q4_0.bin -n 256 --repeat_penalty 1.0 --color -i -r "User:" -f prompts/chat-with-bob.txt
 ```
 
 Note the use of `--color` to distinguish between user input and generated text.
 
 ![image](https://user-images.githubusercontent.com/1991296/224575029-2af3c7dc-5a65-4f64-a6bb-517a532aea38.png)
 
+### Instruction mode with Alpaca
+
+First, download the `ggml` Alpaca model into the `./models` folder:
+
+```
+# use one of these
+# NOTE: these are copied from the alpaca.cpp repo - not sure how long these will work
+# TODO: add a script to simplify the download
+curl -o ggml-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
+curl -o ggml-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
+curl -o ggml-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
+```
+
+Now run the `main` tool like this:
+
+```
+./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins
+```
+
+Sample run:
+
+```
+== Running in interactive mode. ==
+- Press Ctrl+C to interject at any time.
+- Press Return to return control to LLaMa.
+- If you want to submit another line, end your input in '\'.
+
+Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+> How many letters are there in the English alphabet?
+There 26 letters in the English Alphabet
+> What is the most common way of transportation in Amsterdam?
+The majority (54%) are using public transit. This includes buses, trams and metros with over 100 lines throughout the city which make it very accessible for tourists to navigate around town as well as locals who commute by tram or metro on a daily basis
+> List 5 words that start with "ca".
+cadaver, cauliflower, cabbage (vegetable), catalpa (tree) and Cailleach.
+>
+```
+
 ### Android
 
 You can easily run `llama.cpp` on Android device with [termux](https://play.google.com/store/apps/details?id=com.termux).

alpaca.sh (new executable file, 6 lines)

@@ -0,0 +1,6 @@
#!/bin/bash
#
# Temporary script - will be removed in the future
#
./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.96 --repeat_penalty 1 -t 7

convert-pth-to-ggml.py

@@ -16,7 +16,7 @@
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
 #
-import os
+import argparse
 import sys
 import json
 import struct
@@ -24,136 +24,81 @@ import numpy as np
 import torch
 from sentencepiece import SentencePieceProcessor
 
-if len(sys.argv) < 3:
-    print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
-    print("  ftype == 0 -> float32")
-    print("  ftype == 1 -> float16")
-    sys.exit(1)
-
-# output in the same directory as the model
-dir_model = sys.argv[1]
-
-fname_hparams = sys.argv[1] + "/params.json"
-fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
+def parse_args():
+
+    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
+    parser.add_argument('dir_model', help='directory containing the model checkpoint')
+    parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+    return parser.parse_args()
 
 def get_n_parts(dim):
-    if dim == 4096:
-        return 1
-    elif dim == 5120:
-        return 2
-    elif dim == 6656:
-        return 4
-    elif dim == 8192:
-        return 8
-    else:
-        print("Invalid dim: " + str(dim))
+
+    mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
+    n_parts = mappings.get(dim)
+    if n_parts is None:
+        print(f"Invalid dim: {dim}")
         sys.exit(1)
 
-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-
-if os.path.exists(fname_out):
-    print(f"Skip conversion, it already exists: {fname_out}")
-    sys.exit(0)
+    print(f"n_parts = {n_parts}\n")
+    return n_parts
+
+def load_hparams_and_tokenizer(dir_model):
+
+    fname_hparams = f"{dir_model}/params.json"
+    fname_tokenizer = f"{dir_model}/../tokenizer.model"
 
     with open(fname_hparams, "r") as f:
         hparams = json.load(f)
+        print(hparams)
 
     tokenizer = SentencePieceProcessor(fname_tokenizer)
 
-hparams.update({"vocab_size": tokenizer.vocab_size()})
-
-n_parts = get_n_parts(hparams["dim"])
-
-print(hparams)
-print('n_parts = ', n_parts)
-
-for p in range(n_parts):
-    print('Processing part ', p)
-
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-    if (p > 0):
-        fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
-
-    model = torch.load(fname_model, map_location="cpu")
-
-    fout = open(fname_out, "wb")
-
-    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-    fout.write(struct.pack("i", hparams["vocab_size"]))
-    fout.write(struct.pack("i", hparams["dim"]))
-    fout.write(struct.pack("i", hparams["multiple_of"]))
-    fout.write(struct.pack("i", hparams["n_heads"]))
-    fout.write(struct.pack("i", hparams["n_layers"]))
-    fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
-    fout.write(struct.pack("i", ftype))
-
-    # Is this correct??
+    hparams.update({"vocab_size": tokenizer.vocab_size()})
+
+    return hparams, tokenizer
+
+def write_header(fout, hparams, ftype):
+
+    keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
+    values = [
+        0x67676d6c, # magic: ggml in hex
+        *[hparams[key] for key in keys],
+        hparams["dim"] // hparams["n_heads"], # rot (obsolete)
+        ftype
+    ]
+    fout.write(struct.pack("i" * len(values), *values))
+
+def write_tokens(fout, tokenizer):
+
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
-            # "<unk>" token (translated as ??)
             text = " \u2047 ".encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
         elif tokenizer.is_control(i):
-            # "<s>"/"</s>" tokens
-            fout.write(struct.pack("i", 0))
+            text = b""
         elif tokenizer.is_byte(i):
-            # "<U+XX>" tokens (which may be invalid UTF-8)
             piece = tokenizer.id_to_piece(i)
             if len(piece) != 6:
-                print("Invalid token: " + piece)
+                print(f"Invalid token: {piece}")
                 sys.exit(1)
             byte_value = int(piece[3:-1], 16)
-            fout.write(struct.pack("i", 1))
-            fout.write(struct.pack("B", byte_value))
+            text = struct.pack("B", byte_value)
         else:
-            # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
             text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
         fout.write(struct.pack("i", len(text)))
         fout.write(text)
 
-    for k, v in model.items():
-        name = k
-        shape = v.shape
-
-        # skip layers.X.attention.inner_attention.rope.freqs
-        if name[-5:] == "freqs":
+def process_and_write_variables(fout, model, ftype):
+
+    for name, datao in model.items():
+
+        if name.endswith("freqs"):
             continue
 
-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
-
-        #data = tf.train.load_variable(dir_model, name).squeeze()
-        data = v.numpy().squeeze()
-        n_dims = len(data.shape);
-
-        # for efficiency - transpose some matrices
-        # "model/h.*/attn/c_attn/w"
-        # "model/h.*/attn/c_proj/w"
-        # "model/h.*/mlp/c_fc/w"
-        # "model/h.*/mlp/c_proj/w"
-        #if name[-14:] == "/attn/c_attn/w" or \
-        #   name[-14:] == "/attn/c_proj/w" or \
-        #   name[-11:] == "/mlp/c_fc/w" or \
-        #   name[-13:] == "/mlp/c_proj/w":
-        #    print(" Transposing")
-        #    data = data.transpose()
-
-        dshape = data.shape
+        shape = datao.shape
+        print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
+
+        data = datao.numpy().squeeze()
+        n_dims = len(shape)
 
         # default type is fp16
         ftype_cur = 1
@@ -164,18 +109,40 @@ for p in range(n_parts):
         # header
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
-        for i in range(n_dims):
-            fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
-
-        # data
+        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
+        for dim in reversed(data.shape):
+            fout.write(struct.pack("i", dim))
+        fout.write(sname)
+
+        # data output to file
         data.tofile(fout)
 
-    # I hope this deallocates the memory ..
-    model = None
-
-    fout.close()
-
-    print("Done. Output file: " + fname_out + ", (part ", p, ")")
-    print("")
+def main():
+
+    args = parse_args()
+    dir_model = args.dir_model
+    ftype = args.ftype
+    ftype_str = ["f32", "f16"]
+
+    hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+    n_parts = get_n_parts(hparams["dim"])
+
+    for p in range(n_parts):
+
+        print(f"Processing part {p}\n")
+
+        fname_model = f"{dir_model}/consolidated.0{p}.pth"
+        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
+
+        model = torch.load(fname_model, map_location="cpu")
+
+        with open(fname_out, "wb") as fout:
+            write_header(fout, hparams, ftype)
+            write_tokens(fout, tokenizer)
+            process_and_write_variables(fout, model, ftype)
+
+        del model
+        print(f"Done. Output file: {fname_out}, (part {p})\n")
+
+if __name__ == "__main__":
+    main()
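For readers unfamiliar with the format: the header emitted by `write_header` above is just eight 32-bit integers, the magic `0x67676d6c` followed by `vocab_size`, `dim`, `multiple_of`, `n_heads`, `n_layers`, the obsolete `rot` value, and `ftype`. A minimal sketch that reads it back (not part of this commit; the helper name and example path are made up):

```
# sketch only: read back the header written by write_header() above
import struct

def read_ggml_header(path):
    with open(path, "rb") as f:
        magic, = struct.unpack("i", f.read(4))
        assert magic == 0x67676d6c, "not a ggml model file"
        vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = struct.unpack("7i", f.read(28))
    return {"vocab_size": vocab_size, "dim": dim, "multiple_of": multiple_of,
            "n_heads": n_heads, "n_layers": n_layers, "rot": rot, "ftype": ftype}

# e.g. read_ggml_header("models/7B/ggml-model-f16.bin")
```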

ggml.c (4 changed lines)

@@ -5556,7 +5556,7 @@ static void ggml_compute_forward_rms_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+    const ggml_float eps = 1e-6f; // TODO: make this a parameter
 
     // TODO: optimize
     for (int i03 = 0; i03 < ne03; i03++) {
@@ -5572,7 +5572,7 @@ static void ggml_compute_forward_rms_norm_f32(
                 mean /= ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
                 memcpy(y, x, ne00 * sizeof(float));
                 // for (int i00 = 0; i00 < ne00; i00++) {
                 //     y[i00] = x[i00];
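For context on the `eps` change above: in RMS normalization each row is scaled by 1/sqrt(mean(x^2) + eps), so `eps` is only a small stabilizer added before the square root; this commit tightens it from 1e-5 to 1e-6. A NumPy sketch of that computation, for illustration only (not code from this commit):

```
# illustration only: the per-row scaling performed by ggml's RMS norm,
# with the eps constant changed above
import numpy as np

def rms_norm(x, eps=1e-6):
    mean_sq = np.mean(x * x, axis=-1, keepdims=True)  # mean of squares per row
    return x / np.sqrt(mean_sq + eps)                 # scale each row
```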

main.cpp (178 changed lines)

@@ -7,6 +7,7 @@
 #include <cstdio>
 #include <cstring>
 #include <fstream>
+#include <iostream>
 #include <map>
 #include <string>
 #include <vector>
@@ -27,6 +28,8 @@
 #define ANSI_COLOR_RESET "\x1b[0m"
 #define ANSI_BOLD "\x1b[1m"
 
+static const int EOS_TOKEN_ID = 2;
+
 // determine number of model parts based on the dimension
 static const std::map<int, int> LLAMA_N_PARTS = {
     { 4096, 1 },
@@ -86,7 +89,7 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     std::vector<char> f_buf(1024*1024);
@@ -176,8 +179,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         }
     }
 
-    const ggml_type wtype2 = GGML_TYPE_F32;
-
     auto & ctx = model.ctx;
 
     size_t ctx_size = 0;
@@ -209,8 +210,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
     ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
 
-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
+    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
 
     ctx_size += (5 + 10*n_layer)*256; // object overhead
@@ -237,7 +238,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     const int n_embd = hparams.n_embd;
     const int n_layer = hparams.n_layer;
-    const int n_ctx = hparams.n_ctx;
     const int n_vocab = hparams.n_vocab;
 
     model.layers.resize(n_layer);
@@ -296,8 +296,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     const int n_mem = n_layer*n_ctx;
     const int n_elements = n_embd*n_mem;
 
-    model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
-    model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+    model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
+    model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);
 
     const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
@@ -539,9 +539,7 @@ bool llama_eval(
     const int n_vocab = hparams.n_vocab;
     const int n_rot = hparams.n_embd/hparams.n_head;
 
-    const int d_key = n_embd/n_head;
-
     // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
     // static size_t buf_size = hparams.n_ctx*1024*1024;
     static size_t buf_size = 512u*1024*1024;
     static void * buf = malloc(buf_size);
@@ -752,6 +750,7 @@ static bool is_interacting = false;
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 void sigint_handler(int signo) {
     printf(ANSI_COLOR_RESET);
+    printf("\n"); // this also force flush stdout.
     if (signo == SIGINT) {
         if (!is_interacting) {
             is_interacting=true;
@@ -792,7 +791,7 @@ int main(int argc, char ** argv) {
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
 
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
                 "expect poor results\n", __func__, params.n_ctx);
@@ -805,7 +804,7 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
-    if (params.prompt.empty()) {
+    if (params.random_prompt) {
         params.prompt = gpt_random_prompt(rng);
     }
@@ -819,8 +818,9 @@ int main(int argc, char ** argv) {
     // load the model
     {
+        const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -849,8 +849,27 @@ int main(int argc, char ** argv) {
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
 
+    // prefix & suffix for instruct mode
+    const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
+    const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
+
+    // in instruct mode, we inject a prefix and a suffix to each input by the user
+    if (params.instruct) {
+        params.interactive = true;
+        params.antiprompt.push_back("### Instruction:\n\n");
+    }
+
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
+    std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
+
+    for (auto antiprompt : params.antiprompt) {
+        antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
+    }
+
+    // enable interactive mode if reverse prompt is specified
+    if (antipromptv_inp.size() != 0) {
+        params.interactive = true;
+    }
+
     fprintf(stderr, "\n");
     fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
@@ -872,13 +891,16 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: interactive mode on.\n", __func__);
 
-        if(antiprompt_inp.size()) {
-            fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
-            fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
-            for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+        if(antipromptv_inp.size()) {
+            for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
+                auto antiprompt_inp = antipromptv_inp.at(apindex);
+                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
+                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+                for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
+                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                }
+                fprintf(stderr, "\n");
             }
-            fprintf(stderr, "\n");
         }
     }
     fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
@@ -894,31 +916,27 @@ int main(int argc, char ** argv) {
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
             " - Press Ctrl+C to interject at any time.\n"
 #endif
             " - Press Return to return control to LLaMa.\n"
-            " - If you want to submit another line, end your input in '\\'.\n");
-        is_interacting = true;
+            " - If you want to submit another line, end your input in '\\'.\n\n");
     }
 
-    int remaining_tokens = params.n_predict;
     int input_consumed = 0;
     bool input_noecho = false;
 
-    // prompt user immediately after the starting prompt has been loaded
-    if (params.interactive_start) {
-        is_interacting = true;
-    }
+    int remaining_tokens = params.n_predict;
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    while (remaining_tokens > 0) {
+    while (remaining_tokens > 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
@@ -948,6 +966,11 @@ int main(int argc, char ** argv) {
         {
             const int64_t t_start_sample_us = ggml_time_us();
 
+            if (params.ignore_eos) {
+                // set the logit of the eos token to zero to avoid sampling it
+                logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
+            }
+
             id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
 
             last_n_tokens.erase(last_n_tokens.begin());
@@ -971,15 +994,10 @@ int main(int argc, char ** argv) {
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(embd_inp[input_consumed]);
             ++input_consumed;
-            if (embd.size() > params.n_batch) {
+            if ((int) embd.size() >= params.n_batch) {
                 break;
             }
         }
-
-        // reset color to default if we there is no pending user input
-        if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) {
-            printf(ANSI_COLOR_RESET);
-        }
     }
 
     // display text
@@ -989,56 +1007,74 @@ int main(int argc, char ** argv) {
             }
             fflush(stdout);
         }
 
+        // reset color to default if we there is no pending user input
+        if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
+            printf(ANSI_COLOR_RESET);
+        }
+
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
         if (params.interactive && embd_inp.size() <= input_consumed) {
             // check for reverse prompt
-            if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
-                // reverse prompt found
-                is_interacting = true;
+            for (auto antiprompt_inp : antipromptv_inp) {
+                if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+                    // reverse prompt found
+                    is_interacting = true;
+                    break;
+                }
             }
             if (is_interacting) {
-                // currently being interactive
-                bool another_line=true;
-                while (another_line) {
-                    fflush(stdout);
-                    char buf[256] = {0};
-                    int n_read;
-                    if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
-                        // presumable empty line, consume the newline
-                        std::ignore = scanf("%*c");
-                        n_read=0;
-                    }
-                    if(params.use_color) printf(ANSI_COLOR_RESET);
-                    if (n_read > 0 && buf[n_read-1]=='\\') {
-                        another_line = true;
-                        buf[n_read-1] = '\n';
-                        buf[n_read] = 0;
-                    } else {
-                        another_line = false;
-                        buf[n_read] = '\n';
-                        buf[n_read+1] = 0;
-                    }
-                    std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
-                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
-                    remaining_tokens -= line_inp.size();
-                    input_noecho = true; // do not echo this again
-                }
-                is_interacting = false;
+                if (params.instruct) {
+                    input_consumed = embd_inp.size();
+                    embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+
+                    printf("\n> ");
+                }
+
+                // currently being interactive
+                if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+                std::string buffer;
+                std::string line;
+                bool another_line = true;
+                do {
+                    std::getline(std::cin, line);
+                    if (line.empty() || line.back() != '\\') {
+                        another_line = false;
+                    } else {
+                        line.pop_back(); // Remove the continue character
+                    }
+                    buffer += line + '\n'; // Append the line to the result
+                } while (another_line);
+
+                if (params.use_color) printf(ANSI_COLOR_RESET);
+
+                std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
+                embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+                if (params.instruct) {
+                    embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                }
+
+                remaining_tokens -= line_inp.size();
+
+                input_noecho = true; // do not echo this again
             }
+            is_interacting = false;
         }
 
         // end of text token
-        if (embd.back() == 2) {
-            fprintf(stderr, " [end of text]\n");
-            break;
+        if (embd.back() == EOS_TOKEN_ID) {
+            if (params.interactive) {
+                is_interacting = true;
+            } else {
+                fprintf(stderr, " [end of text]\n");
+                break;
+            }
+        }
+
+        // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
+        if (params.interactive && remaining_tokens <= 0) {
+            remaining_tokens = params.n_predict;
+            is_interacting = true;
         }
     }
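The reverse-prompt handling above now compares the tail of the recent-token window against every tokenized antiprompt. A small Python sketch of the idea, for illustration only (function name and token ids are made up):

```
# sketch of the reverse-prompt check: pause generation when the most recent
# tokens end with any of the tokenized reverse prompts
def hits_reverse_prompt(last_tokens, antipromptv):
    for antiprompt in antipromptv:
        if antiprompt and last_tokens[-len(antiprompt):] == antiprompt:
            return True
    return False

# hits_reverse_prompt([512, 48, 9038, 29901], [[9038, 29901]]) -> True
```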

prompts/alpaca.txt (new file, 1 line)

@@ -0,0 +1 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.

prompts/chat-with-bob.txt (new file, 7 lines)

@@ -0,0 +1,7 @@
Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
User:

quantize.py (new file, 126 lines)

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
"""Script to execute the "quantize" script on a given set of models."""
import subprocess
import argparse
import glob
import sys
import os
def main():
"""Update the quantize binary name depending on the platform and parse
the command line arguments and execute the script.
"""
if "linux" in sys.platform or "darwin" in sys.platform:
quantize_script_binary = "quantize"
elif "win32" in sys.platform or "cygwin" in sys.platform:
quantize_script_binary = "quantize.exe"
else:
print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n")
quantize_script_binary = "quantize"
parser = argparse.ArgumentParser(
prog='python3 quantize.py',
description='This script quantizes the given models by applying the '
f'"{quantize_script_binary}" script on them.'
)
parser.add_argument(
'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
help='The models to quantize.'
)
parser.add_argument(
'-r', '--remove-16', action='store_true', dest='remove_f16',
help='Remove the f16 model after quantizing it.'
)
parser.add_argument(
'-m', '--models-path', dest='models_path',
default=os.path.join(os.getcwd(), "models"),
help='Specify the directory where the models are located.'
)
parser.add_argument(
'-q', '--quantize-script-path', dest='quantize_script_path',
default=os.path.join(os.getcwd(), quantize_script_binary),
help='Specify the path to the "quantize" script.'
)
# TODO: Revise this code
# parser.add_argument(
# '-t', '--threads', dest='threads', type='int',
# default=os.cpu_count(),
# help='Specify the number of threads to use to quantize many models at '
# 'once. Defaults to os.cpu_count().'
# )
args = parser.parse_args()
if not os.path.isfile(args.quantize_script_path):
print(
f'The "{quantize_script_binary}" script was not found in the '
"current location.\nIf you want to use it from another location, "
"set the --quantize-script-path argument from the command line."
)
sys.exit(1)
for model in args.models:
# The model is separated in various parts
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
f16_model_path_base = os.path.join(
args.models_path, model, "ggml-model-f16.bin"
)
f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*")
)
for f16_model_part_path in f16_model_parts_paths:
if not os.path.isfile(f16_model_part_path):
print(
f"The f16 model {os.path.basename(f16_model_part_path)} "
f"was not found in {args.models_path}{os.path.sep}{model}"
". If you want to use it from another location, set the "
"--models-path argument from the command line."
)
sys.exit(1)
__run_quantize_script(
args.quantize_script_path, f16_model_part_path
)
if args.remove_f16:
os.remove(f16_model_part_path)
# This was extracted to a top-level function for parallelization, if
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
def __run_quantize_script(script_path, f16_model_part_path):
"""Run the quantize script specifying the path to it and the path to the
f16 model to quantize.
"""
new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
subprocess.run(
[script_path, f16_model_part_path, new_quantized_model_path, "2"],
check=True
)
if __name__ == "__main__":
try:
main()
except subprocess.CalledProcessError:
print("\nAn error ocurred while trying to quantize the models.")
sys.exit(1)
except KeyboardInterrupt:
sys.exit(0)
else:
print("\nSuccesfully quantized all models.")

quantize.sh (deleted file, 15 lines)

@@ -1,15 +0,0 @@
#!/usr/bin/env bash
if ! [[ "$1" =~ ^[0-9]{1,2}B$ ]]; then
echo
echo "Usage: quantize.sh 7B|13B|30B|65B [--remove-f16]"
echo
exit 1
fi
for i in `ls models/$1/ggml-model-f16.bin*`; do
./quantize "$i" "${i/f16/q4_0}" 2
if [[ "$2" == "--remove-f16" ]]; then
rm "$i"
fi
done

utils.cpp

@@ -38,19 +38,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "-p" || arg == "--prompt") {
             params.prompt = argv[++i];
         } else if (arg == "-f" || arg == "--file") {
             std::ifstream file(argv[++i]);
-            std::copy(std::istreambuf_iterator<char>(file),
-                      std::istreambuf_iterator<char>(),
-                      back_inserter(params.prompt));
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
+            if (params.prompt.back() == '\n') {
+                params.prompt.pop_back();
+            }
         } else if (arg == "-n" || arg == "--n_predict") {
             params.n_predict = std::stoi(argv[++i]);
         } else if (arg == "--top_k") {
             params.top_k = std::stoi(argv[++i]);
         } else if (arg == "-c" || arg == "--ctx_size") {
             params.n_ctx = std::stoi(argv[++i]);
+        } else if (arg == "--memory_f16") {
+            params.memory_f16 = true;
         } else if (arg == "--top_p") {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
@@ -65,16 +65,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.model = argv[++i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
-        } else if (arg == "--interactive-start") {
-            params.interactive = true;
-            params.interactive_start = true;
+        } else if (arg == "-ins" || arg == "--instruct") {
+            params.instruct = true;
         } else if (arg == "--color") {
             params.use_color = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
-            params.antiprompt = argv[++i];
+            params.antiprompt.push_back(argv[++i]);
+        } else if (arg == "--ignore-eos") {
+            params.ignore_eos = true;
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
+        } else if (arg == "--random-prompt") {
+            params.random_prompt = true;
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, params);
@@ -85,20 +88,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
-void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
+void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, " -h, --help show this help message and exit\n");
     fprintf(stderr, " -i, --interactive run in interactive mode\n");
-    fprintf(stderr, " --interactive-start run in interactive mode and poll user input at startup\n");
+    fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
     fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
-    fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT\n");
+    fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT (can be\n");
+    fprintf(stderr, " specified more than once for multiple prompts).\n");
     fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
     fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
     fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
-    fprintf(stderr, " prompt to start generation with (default: random)\n");
+    fprintf(stderr, " prompt to start generation with (default: empty)\n");
+    fprintf(stderr, " --random-prompt start with a randomized prompt.\n");
     fprintf(stderr, " -f FNAME, --file FNAME\n");
     fprintf(stderr, " prompt file to start generation.\n");
     fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
@@ -107,6 +112,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
     fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
     fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
+    fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n");
     fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, " -m FNAME, --model FNAME\n");
@@ -398,7 +405,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
                 logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
             } else {
                 logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
             }
         } else {
             logits_id.push_back(std::make_pair(logits[i]*scale, i));
         }

utils.h (11 changed lines)

@@ -18,6 +18,7 @@ struct gpt_params {
     int32_t n_predict = 128; // new tokens to predict
     int32_t repeat_last_n = 64; // last n tokens to penalize
     int32_t n_ctx = 512; //context size
+    bool memory_f16 = false; // use f16 instead of f32 for memory kv
 
     // sampling parameters
     int32_t top_k = 40;
@@ -27,14 +28,18 @@ struct gpt_params {
     int32_t n_batch = 8; // batch size for prompt processing
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
-    std::string prompt;
+    std::string prompt = "";
+    bool random_prompt = false;
 
     bool use_color = false; // use color to distinguish generations and inputs
     bool interactive = false; // interactive mode
     bool interactive_start = false; // reverse prompt immediately
 
-    std::string antiprompt = ""; // string upon seeing which more user input is prompted
+    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+
+    bool instruct = false; // instruction mode (used for Alpaca models)
+    bool ignore_eos = false; // do not stop generating after eos
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);