Merge branch 'master' into concedo

# Conflicts:
#	.devops/full.Dockerfile
#	README.md
#	main.cpp
Commit a2c10e0d2f by Concedo, 2023-03-20 20:58:27 +08:00
15 changed files with 564 additions and 296 deletions


@@ -31,7 +31,7 @@ endif
 #
 CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
+CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++17 -fPIC
 LDFLAGS =
 # OS specific

alpaca.sh (new executable file)

@@ -0,0 +1,6 @@
#!/bin/bash
#
# Temporary script - will be removed in the future
#
./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.96 --repeat_penalty 1 -t 7


@@ -16,7 +16,7 @@
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
 #
-import os
+import argparse
 import sys
 import json
 import struct
@@ -24,136 +24,83 @@ import numpy as np
 import torch
 from sentencepiece import SentencePieceProcessor

-if len(sys.argv) < 3:
-    print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
-    print("  ftype == 0 -> float32")
-    print("  ftype == 1 -> float16")
-    sys.exit(1)
-
-# output in the same directory as the model
-dir_model = sys.argv[1]
-
-fname_hparams = sys.argv[1] + "/params.json"
-fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
+def parse_args():
+
+    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
+    parser.add_argument('dir_model', help='directory containing the model checkpoint')
+    parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+    return parser.parse_args()

 def get_n_parts(dim):
-    if dim == 4096:
-        return 1
-    elif dim == 5120:
-        return 2
-    elif dim == 6656:
-        return 4
-    elif dim == 8192:
-        return 8
-    else:
-        print("Invalid dim: " + str(dim))
+
+    mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
+    n_parts = mappings.get(dim)
+    if n_parts is None:
+        print(f"Invalid dim: {dim}")
         sys.exit(1)

-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-
-if os.path.exists(fname_out):
-    print(f"Skip conversion, it already exists: {fname_out}")
-    sys.exit(0)
+    print(f"n_parts = {n_parts}\n")
+    return n_parts

-with open(fname_hparams, "r") as f:
-    hparams = json.load(f)
+def load_hparams_and_tokenizer(dir_model):
+
+    fname_hparams = f"{dir_model}/params.json"
+    fname_tokenizer = f"{dir_model}/../tokenizer.model"
+
+    with open(fname_hparams, "r") as f:
+        hparams = json.load(f)
+        print(hparams)

-tokenizer = SentencePieceProcessor(fname_tokenizer)
+    tokenizer = SentencePieceProcessor(fname_tokenizer)
+    hparams.update({"vocab_size": tokenizer.vocab_size()})

-hparams.update({"vocab_size": tokenizer.vocab_size()})
+    return hparams, tokenizer

-n_parts = get_n_parts(hparams["dim"])
-
-print(hparams)
-print('n_parts = ', n_parts)
+def write_header(fout, hparams, ftype):
+
+    keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
+    values = [
+        0x67676d66,  # magic: ggml in hex
+        1, # file version
+        *[hparams[key] for key in keys],
+        hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
+        ftype
+    ]
+    fout.write(struct.pack("i" * len(values), *values))

-for p in range(n_parts):
-    print('Processing part ', p)
-
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-    if (p > 0):
-        fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
-
-    model = torch.load(fname_model, map_location="cpu")
-
-    fout = open(fname_out, "wb")
-
-    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-    fout.write(struct.pack("i", hparams["vocab_size"]))
-    fout.write(struct.pack("i", hparams["dim"]))
-    fout.write(struct.pack("i", hparams["multiple_of"]))
-    fout.write(struct.pack("i", hparams["n_heads"]))
-    fout.write(struct.pack("i", hparams["n_layers"]))
-    fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
-    fout.write(struct.pack("i", ftype))
-
-    # Is this correct??
+def write_tokens(fout, tokenizer):
+
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
-            # "<unk>" token (translated as ??)
             text = " \u2047 ".encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
         elif tokenizer.is_control(i):
-            # "<s>"/"</s>" tokens
-            fout.write(struct.pack("i", 0))
+            text = b""
         elif tokenizer.is_byte(i):
-            # "<U+XX>" tokens (which may be invalid UTF-8)
             piece = tokenizer.id_to_piece(i)
             if len(piece) != 6:
-                print("Invalid token: " + piece)
+                print(f"Invalid token: {piece}")
                 sys.exit(1)
             byte_value = int(piece[3:-1], 16)
-            fout.write(struct.pack("i", 1))
-            fout.write(struct.pack("B", byte_value))
+            text = struct.pack("B", byte_value)
         else:
-            # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
             text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        fout.write(struct.pack("f", tokenizer.get_score(i)))

-    for k, v in model.items():
-        name = k
-        shape = v.shape
-
-        # skip layers.X.attention.inner_attention.rope.freqs
-        if name[-5:] == "freqs":
+def process_and_write_variables(fout, model, ftype):
+
+    for name, datao in model.items():
+
+        if name.endswith("freqs"):
             continue

-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
-
-        #data = tf.train.load_variable(dir_model, name).squeeze()
-        data = v.numpy().squeeze()
-        n_dims = len(data.shape);
-
-        # for efficiency - transpose some matrices
-        # "model/h.*/attn/c_attn/w"
-        # "model/h.*/attn/c_proj/w"
-        # "model/h.*/mlp/c_fc/w"
-        # "model/h.*/mlp/c_proj/w"
-        #if name[-14:] == "/attn/c_attn/w" or \
-        #   name[-14:] == "/attn/c_proj/w" or \
-        #   name[-11:] == "/mlp/c_fc/w" or \
-        #   name[-13:] == "/mlp/c_proj/w":
-        #    print("  Transposing")
-        #    data = data.transpose()
-
-        dshape = data.shape
+        shape = datao.shape
+
+        print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
+
+        data = datao.numpy().squeeze()
+        n_dims = len(shape)

         # default type is fp16
         ftype_cur = 1
@@ -164,18 +111,40 @@ for p in range(n_parts):

         # header
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
-        for i in range(n_dims):
-            fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
+        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
+        for dim in reversed(data.shape):
+            fout.write(struct.pack("i", dim))
+        fout.write(sname)

-        # data
+        # data output to file
         data.tofile(fout)

-    # I hope this deallocates the memory ..
-    model = None
-
-    fout.close()
-
-    print("Done. Output file: " + fname_out + ", (part ", p, ")")
-    print("")
+def main():
+
+    args = parse_args()
+    dir_model = args.dir_model
+    ftype = args.ftype
+    ftype_str = ["f32", "f16"]
+
+    hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+    n_parts = get_n_parts(hparams["dim"])
+
+    for p in range(n_parts):
+
+        print(f"Processing part {p}\n")
+
+        fname_model = f"{dir_model}/consolidated.0{p}.pth"
+        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
+
+        model = torch.load(fname_model, map_location="cpu")
+
+        with open(fname_out, "wb") as fout:
+            write_header(fout, hparams, ftype)
+            write_tokens(fout, tokenizer)
+            process_and_write_variables(fout, model, ftype)
+
+        del model
+
+        print(f"Done. Output file: {fname_out}, (part {p})\n")
+
+if __name__ == "__main__":
+    main()
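
A note on the format change above: the old converter wrote files that start with the magic 0x67676d6c and no version field, while the refactored write_header/write_tokens pair emits the magic 0x67676d66, a format version of 1, and a per-token score after each token's bytes. Below is a minimal Python sketch of reading the fixed-size part of that header back; read_ggmf_header is a hypothetical helper written for illustration only, it is not part of this commit.

import struct

def read_ggmf_header(path):
    # Layout written by write_header(): magic, version, vocab_size, dim,
    # multiple_of, n_heads, n_layers, rot (obsolete), ftype -- all int32.
    with open(path, "rb") as f:
        fields = struct.unpack("9i", f.read(9 * 4))
    magic, version = fields[0], fields[1]
    assert magic == 0x67676d66, "not a versioned ggml model file"
    assert version == 1, "unsupported format version"
    names = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers", "rot", "ftype"]
    return dict(zip(names, fields[2:]))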


@@ -55,7 +55,7 @@ extern "C" {
     int n_parts_overwrite = inputs.n_parts_overwrite;

-    if (!llama_model_load(api_params.model, api_model, api_vocab, api_params.n_ctx, n_parts_overwrite)) {
+    if (!llama_model_load(api_params.model, api_model, api_vocab, api_params.n_ctx, GGML_TYPE_F16, n_parts_overwrite)) {
         fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, api_params.model.c_str());
         return false;
     }
@@ -101,7 +101,6 @@ extern "C" {
     if(reset_state)
     {
         api_params.prompt.insert(0, 1, ' ');
-        mem_per_token = 0;
     }

     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(api_vocab, api_params.prompt, true);
@@ -164,7 +163,7 @@ extern "C" {
     {
         // set the logit of the eos token (2) to zero to avoid sampling it
-        api_logits[api_logits.size() - n_vocab + 2] = 0;
+        api_logits[api_logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
         //set logits of opening square bracket to zero.
         api_logits[api_logits.size() - n_vocab + 518] = 0;
         api_logits[api_logits.size() - n_vocab + 29961] = 0;
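
EOS_TOKEN_ID here is the constant introduced in main.cpp for LLaMA's end-of-text token (id 2); the line above zeroes its logit in the last position of the flat logits buffer so sampling practically never picks it. A small illustrative Python equivalent, assuming a plain list of logits whose final n_vocab entries belong to the last evaluated position (suppress_eos is a made-up name, not an API in this repository):

EOS_TOKEN_ID = 2  # LLaMA end-of-text token, as defined in main.cpp

def suppress_eos(logits, n_vocab):
    # Zero the EOS logit of the last position so top-k/top-p sampling
    # (almost) never selects it -- mirrors --ignore-eos in the diff.
    logits[len(logits) - n_vocab + EOS_TOKEN_ID] = 0.0
    return logits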

Binary file not shown.

main.cpp

@@ -3,10 +3,12 @@
 #include "utils.h"

 #include <cassert>
+#include <cinttypes>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 #include <fstream>
+#include <iostream>
 #include <map>
 #include <string>
 #include <vector>
@@ -27,6 +29,8 @@
 #define ANSI_COLOR_RESET "\x1b[0m"
 #define ANSI_BOLD "\x1b[1m"

+static const int EOS_TOKEN_ID = 2;
+
 // determine number of model parts based on the dimension
 static const std::map<int, int> LLAMA_N_PARTS = {
     { 4096, 1 },
@@ -86,30 +90,40 @@ struct llama_model {
 };

 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, int n_parts_overwrite=-1) {
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32, int n_parts_overwrite=-1) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

-    FILE *fin = fopen(fname.data(), "rb");
+    std::vector<char> f_buf(1024*1024);
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
     if (!fin) {
         fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
         return false;
     }

-    // Having a large buffer helps to accelerate load considerably (old buffer was 1024 * 1024).
-    // Though I am not sure if it's okay for edge devices like Raspberry Pi.
-    std::vector<char> f_buf(128 * 1024 * 1024);
-    setvbuf(fin, f_buf.data(), _IOFBF, f_buf.size());
-
     // verify magic
     {
         uint32_t magic;
-        fread((char *) &magic, 1, sizeof(magic), fin);
-        if (magic != 0x67676d6c) {
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic == 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
+                    __func__, fname.c_str());
+            return false;
+        }
+        if (magic != 0x67676d66) {
             fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
             return false;
         }
+
+        uint32_t format_version;
+        fin.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != 1) {
+            fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ")\n",
+                    __func__, fname.c_str(), format_version);
+            return false;
+        }
     }

     int n_ff = 0;
@@ -119,14 +133,14 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     {
         auto & hparams = model.hparams;

-        fread((char *) &hparams.n_vocab, 1, sizeof(hparams.n_vocab), fin);
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
         //fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
-        fread((char *) &hparams.n_embd, 1, sizeof(hparams.n_embd), fin);
-        fread((char *) &hparams.n_mult, 1, sizeof(hparams.n_mult), fin);
-        fread((char *) &hparams.n_head, 1, sizeof(hparams.n_head), fin);
-        fread((char *) &hparams.n_layer, 1, sizeof(hparams.n_layer), fin);
-        fread((char *) &hparams.n_rot, 1, sizeof(hparams.n_rot), fin);
-        fread((char *) &hparams.f16, 1, sizeof(hparams.f16), fin);
+        fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_mult, sizeof(hparams.n_mult));
+        fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
+        fin.read((char *) &hparams.f16, sizeof(hparams.f16));

         hparams.n_ctx = n_ctx;
@@ -154,13 +168,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         std::string word;
         for (int i = 0; i < model.hparams.n_vocab; i++) {
             uint32_t len;
-            fread((char *) &len, 1, sizeof(len), fin);
+            fin.read((char *) &len, sizeof(len));

             word.resize(len);
-            fread((char *) word.data(), 1, len, fin);
+            fin.read((char *) word.data(), len);
+
+            float score;
+            fin.read((char *) &score, sizeof(score));

             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
+            vocab.score[i] = score;

             //if (i < 30000) {
             //    fprintf(stderr, "%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
@@ -184,8 +202,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         }
     }

-    const ggml_type wtype2 = GGML_TYPE_F32;
-
     auto & ctx = model.ctx;

     size_t ctx_size = 0;
@@ -217,8 +233,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
         ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3

-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v

         ctx_size += (5 + 10*n_layer)*256; // object overhead
@@ -245,7 +261,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         const int n_embd = hparams.n_embd;
         const int n_layer = hparams.n_layer;
-        const int n_ctx = hparams.n_ctx;
         const int n_vocab = hparams.n_vocab;

         model.layers.resize(n_layer);
@@ -304,17 +319,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         const int n_mem = n_layer*n_ctx;
         const int n_elements = n_embd*n_mem;

-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);

         const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

         fprintf(stderr, "%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
     }

-    const size_t file_offset = ftell(fin);
+    const size_t file_offset = fin.tellg();

-    fclose(fin);
+    fin.close();

     std::vector<uint8_t> tmp;
@@ -329,9 +344,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab

         fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());

-        fin = fopen(fname_part.data(), "rb");
-        setvbuf(fin, f_buf.data(), _IOFBF, f_buf.size());
-        fseek(fin, file_offset, SEEK_CUR);
+        fin = std::ifstream(fname_part, std::ios::binary);
+        fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+        fin.seekg(file_offset);

         // load weights
         {
             int n_tensors = 0;
@@ -344,24 +360,23 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
                 int32_t length;
                 int32_t ftype;

-                fread(reinterpret_cast<char *>(&n_dims), 1, sizeof(n_dims), fin);
-                fread(reinterpret_cast<char *>(&length), 1, sizeof(length), fin);
-                fread(reinterpret_cast<char *>(&ftype), 1, sizeof(ftype), fin);
+                fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+                fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+                fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));

-                if (feof(fin)) {
+                if (fin.eof()) {
                     break;
                 }

                 int32_t nelements = 1;
                 int32_t ne[2] = { 1, 1 };
                 for (int i = 0; i < n_dims; ++i) {
-                    fread(reinterpret_cast<char *>(&ne[i]), 1, sizeof(ne[i]), fin);
+                    fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
                     nelements *= ne[i];
                 }

                 std::string name(length, 0);
-                fread(&name[0], 1, length, fin);
+                fin.read(&name[0], length);

                 if (model.tensors.find(name.data()) == model.tensors.end()) {
                     fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
@@ -463,9 +478,9 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
                 }

                 if (part_id == 0) {
-                    fread(reinterpret_cast<char *>(tensor->data), 1, ggml_nbytes(tensor), fin);
+                    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
                 } else {
-                    fseek(fin, ggml_nbytes(tensor), SEEK_CUR);
+                    fin.seekg(ggml_nbytes(tensor), std::ios::cur);
                 }

                 total_size += ggml_nbytes(tensor);
@@ -485,7 +500,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
                     for (int i1 = 0; i1 < ne[1]; ++i1) {
                         const size_t offset_row = i1*row_size;
                         const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
-                        fread(reinterpret_cast<char *>(tensor->data) + offset, 1, row_size/n_parts, fin);
+                        fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
                     }
                 } else {
                     const int np1 = ne[1];
@@ -494,7 +509,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab

                     for (int i1 = 0; i1 < ne[1]; ++i1) {
                         const size_t offset_row = (i1 + part_id*np1)*row_size;
-                        fread(reinterpret_cast<char *>(tensor->data) + offset_row, 1, row_size, fin);
+                        fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
                     }
                 }
@@ -513,11 +528,12 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
             fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
         }

-        fclose(fin);
+        fin.close();
     }

     return true;
 }

 // evaluate the transformer
 //
 //   - model:     the model
@@ -546,9 +562,9 @@ bool llama_eval(
     const int n_vocab = hparams.n_vocab;
     const int n_rot   = hparams.n_embd/hparams.n_head;

-    const int d_key = n_embd/n_head;
-    static size_t buf_size = (size_t)hparams.n_ctx*1024*1024;
+    // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
+    // static size_t buf_size = hparams.n_ctx*1024*1024;
+    static size_t buf_size = 512u*1024*1024;
     static void * buf = malloc(buf_size);

     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
@@ -757,6 +773,7 @@ static bool is_interacting = false;
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 void sigint_handler(int signo) {
     printf(ANSI_COLOR_RESET);
+    printf("\n"); // this also force flush stdout.
     if (signo == SIGINT) {
         if (!is_interacting) {
             is_interacting=true;
@@ -797,7 +814,7 @@ int main(int argc, char ** argv) {
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }

     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
                 "expect poor results\n", __func__, params.n_ctx);
@@ -810,7 +827,7 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);

     std::mt19937 rng(params.seed);
-    if (params.prompt.empty()) {
+    if (params.random_prompt) {
         params.prompt = gpt_random_prompt(rng);
     }
@@ -824,8 +841,9 @@ int main(int argc, char ** argv) {
     // load the model
     {
+        const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -854,8 +872,27 @@ int main(int argc, char ** argv) {
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());

+    // prefix & suffix for instruct mode
+    const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
+    const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
+
+    // in instruct mode, we inject a prefix and a suffix to each input by the user
+    if (params.instruct) {
+        params.interactive = true;
+        params.antiprompt.push_back("### Instruction:\n\n");
+    }
+
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
+    std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
+
+    for (auto antiprompt : params.antiprompt) {
+        antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
+    }
+
+    // enable interactive mode if reverse prompt is specified
+    if (antipromptv_inp.size() != 0) {
+        params.interactive = true;
+    }

     fprintf(stderr, "\n");
     fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
@@ -877,13 +914,16 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: interactive mode on.\n", __func__);

-        if(antiprompt_inp.size()) {
-            fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
-            fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
-            for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+        if(antipromptv_inp.size()) {
+            for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
+                auto antiprompt_inp = antipromptv_inp.at(apindex);
+                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
+                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+                for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
+                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                }
+                fprintf(stderr, "\n");
             }
-            fprintf(stderr, "\n");
         }
     }
     fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
@@ -899,31 +939,27 @@ int main(int argc, char ** argv) {
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);

     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
                " - Press Return to return control to LLaMa.\n"
-               " - If you want to submit another line, end your input in '\\'.\n");
+               " - If you want to submit another line, end your input in '\\'.\n\n");
+        is_interacting = true;
     }

-    int remaining_tokens = params.n_predict;
     int input_consumed = 0;
     bool input_noecho = false;

-    // prompt user immediately after the starting prompt has been loaded
-    if (params.interactive_start) {
-        is_interacting = true;
-    }
+    int remaining_tokens = params.n_predict;

     // set the color for the prompt which will be output initially
     if (params.use_color) {
         printf(ANSI_COLOR_YELLOW);
     }

-    while (remaining_tokens > 0) {
+    while (remaining_tokens > 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
@@ -953,8 +989,10 @@ int main(int argc, char ** argv) {
         {
             const int64_t t_start_sample_us = ggml_time_us();

-            // set the logit of the eos token (2) to zero to avoid sampling it
-            logits[logits.size() - n_vocab + 2] = 0;
+            if (params.ignore_eos) {
+                // set the logit of the eos token to zero to avoid sampling it
+                logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
+            }

             id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
@@ -979,15 +1017,10 @@ int main(int argc, char ** argv) {
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(embd_inp[input_consumed]);
             ++input_consumed;
-            if (embd.size() > params.n_batch) {
+            if ((int) embd.size() >= params.n_batch) {
                 break;
             }
         }
-
-        // reset color to default if we there is no pending user input
-        if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) {
-            printf(ANSI_COLOR_RESET);
-        }
     }

     // display text
@@ -997,56 +1030,74 @@ int main(int argc, char ** argv) {
             }
             fflush(stdout);
         }

+        // reset color to default if we there is no pending user input
+        if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
+            printf(ANSI_COLOR_RESET);
+        }
+
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
         if (params.interactive && embd_inp.size() <= input_consumed) {
             // check for reverse prompt
-            if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
-                // reverse prompt found
-                is_interacting = true;
+            for (auto antiprompt_inp : antipromptv_inp) {
+                if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+                    // reverse prompt found
+                    is_interacting = true;
+                    break;
+                }
             }
             if (is_interacting) {
-                // currently being interactive
-                bool another_line=true;
-                while (another_line) {
-                    fflush(stdout);
-                    char buf[256] = {0};
-                    int n_read;
-                    if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
-                        // presumable empty line, consume the newline
-                        std::ignore = scanf("%*c");
-                        n_read=0;
-                    }
-                    if(params.use_color) printf(ANSI_COLOR_RESET);
-
-                    if (n_read > 0 && buf[n_read-1]=='\\') {
-                        another_line = true;
-                        buf[n_read-1] = '\n';
-                        buf[n_read] = 0;
-                    } else {
-                        another_line = false;
-                        buf[n_read] = '\n';
-                        buf[n_read+1] = 0;
-                    }
-
-                    std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
-                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
-
-                    remaining_tokens -= line_inp.size();
-
-                    input_noecho = true; // do not echo this again
+                if (params.instruct) {
+                    input_consumed = embd_inp.size();
+                    embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+
+                    printf("\n> ");
                 }
-                is_interacting = false;
+
+                // currently being interactive
+                if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+                std::string buffer;
+                std::string line;
+                bool another_line = true;
+                do {
+                    std::getline(std::cin, line);
+                    if (line.empty() || line.back() != '\\') {
+                        another_line = false;
+                    } else {
+                        line.pop_back(); // Remove the continue character
+                    }
+                    buffer += line + '\n'; // Append the line to the result
+                } while (another_line);
+                if (params.use_color) printf(ANSI_COLOR_RESET);
+
+                std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
+                embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+                if (params.instruct) {
+                    embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                }
+
+                remaining_tokens -= line_inp.size();
+
+                input_noecho = true; // do not echo this again
             }
+            is_interacting = false;
         }

         // end of text token
-        if (embd.back() == 2) {
-            fprintf(stderr, " [end of text]\n");
-            break;
+        if (embd.back() == EOS_TOKEN_ID) {
+            if (params.interactive) {
+                is_interacting = true;
+            } else {
+                fprintf(stderr, " [end of text]\n");
+                break;
+            }
+        }
+
+        // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
+        if (params.interactive && remaining_tokens <= 0) {
+            remaining_tokens = params.n_predict;
+            is_interacting = true;
         }
     }
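
Since --reverse-prompt can now be given more than once, the interactive loop above compares the tail of last_n_tokens against every tokenized reverse prompt instead of a single one. A rough Python rendering of that suffix check, for illustration only (hit_reverse_prompt is a made-up helper name, not code from this commit):

def hit_reverse_prompt(last_n_tokens, antipromptv_inp):
    # True if the generated token history ends with any of the tokenized
    # reverse prompts, mirroring std::equal over the reversed ranges.
    for antiprompt_inp in antipromptv_inp:
        n = len(antiprompt_inp)
        if n and last_n_tokens[-n:] == antiprompt_inp:
            return True
    return False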

main.exe (binary file, not shown)

prompts/alpaca.txt (new file)

@@ -0,0 +1 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.


@@ -0,0 +1,7 @@
Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
User:


@@ -3,6 +3,7 @@
 #include "utils.h"

 #include <cassert>
+#include <cinttypes>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
@@ -63,12 +64,28 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
     {
         uint32_t magic;
         finp.read((char *) &magic, sizeof(magic));
-        if (magic != 0x67676d6c) {
+        if (magic == 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
+                    __func__, fname_inp.c_str());
+            return false;
+        }
+        if (magic != 0x67676d66) {
             fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
             return false;
         }

         fout.write((char *) &magic, sizeof(magic));
+
+        uint32_t format_version;
+        finp.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != 1) {
+            fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ")\n",
+                    __func__, fname_inp.c_str(), format_version);
+            return false;
+        }
+
+        fout.write((char *) &format_version, sizeof(format_version));
     }

     llama_hparams hparams;
@@ -122,8 +139,13 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
             finp.read ((char *) word.data(), len);
             fout.write((char *) word.data(), len);

+            float score;
+            finp.read ((char *) &score, sizeof(score));
+            fout.write((char *) &score, sizeof(score));
+
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
+            vocab.score[i] = score;
         }
     }

Binary file not shown.

quantize.py (new file)

@@ -0,0 +1,126 @@
#!/usr/bin/env python3

"""Script to execute the "quantize" script on a given set of models."""

import subprocess
import argparse
import glob
import sys
import os


def main():
    """Update the quantize binary name depending on the platform and parse
    the command line arguments and execute the script.
    """

    if "linux" in sys.platform or "darwin" in sys.platform:
        quantize_script_binary = "quantize"

    elif "win32" in sys.platform or "cygwin" in sys.platform:
        quantize_script_binary = "quantize.exe"

    else:
        print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n")
        quantize_script_binary = "quantize"

    parser = argparse.ArgumentParser(
        prog='python3 quantize.py',
        description='This script quantizes the given models by applying the '
        f'"{quantize_script_binary}" script on them.'
    )
    parser.add_argument(
        'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
        help='The models to quantize.'
    )
    parser.add_argument(
        '-r', '--remove-16', action='store_true', dest='remove_f16',
        help='Remove the f16 model after quantizing it.'
    )
    parser.add_argument(
        '-m', '--models-path', dest='models_path',
        default=os.path.join(os.getcwd(), "models"),
        help='Specify the directory where the models are located.'
    )
    parser.add_argument(
        '-q', '--quantize-script-path', dest='quantize_script_path',
        default=os.path.join(os.getcwd(), quantize_script_binary),
        help='Specify the path to the "quantize" script.'
    )

    # TODO: Revise this code
    # parser.add_argument(
    #     '-t', '--threads', dest='threads', type='int',
    #     default=os.cpu_count(),
    #     help='Specify the number of threads to use to quantize many models at '
    #     'once. Defaults to os.cpu_count().'
    # )

    args = parser.parse_args()

    if not os.path.isfile(args.quantize_script_path):
        print(
            f'The "{quantize_script_binary}" script was not found in the '
            "current location.\nIf you want to use it from another location, "
            "set the --quantize-script-path argument from the command line."
        )
        sys.exit(1)

    for model in args.models:
        # The model is separated in various parts
        # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
        f16_model_path_base = os.path.join(
            args.models_path, model, "ggml-model-f16.bin"
        )

        f16_model_parts_paths = map(
            lambda filename: os.path.join(f16_model_path_base, filename),
            glob.glob(f"{f16_model_path_base}*")
        )

        for f16_model_part_path in f16_model_parts_paths:
            if not os.path.isfile(f16_model_part_path):
                print(
                    f"The f16 model {os.path.basename(f16_model_part_path)} "
                    f"was not found in {args.models_path}{os.path.sep}{model}"
                    ". If you want to use it from another location, set the "
                    "--models-path argument from the command line."
                )
                sys.exit(1)

            __run_quantize_script(
                args.quantize_script_path, f16_model_part_path
            )

            if args.remove_f16:
                os.remove(f16_model_part_path)


# This was extracted to a top-level function for parallelization, if
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
def __run_quantize_script(script_path, f16_model_part_path):
    """Run the quantize script specifying the path to it and the path to the
    f16 model to quantize.
    """

    new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
    subprocess.run(
        [script_path, f16_model_part_path, new_quantized_model_path, "2"],
        check=True
    )


if __name__ == "__main__":
    try:
        main()
    except subprocess.CalledProcessError:
        print("\nAn error ocurred while trying to quantize the models.")
        sys.exit(1)

    except KeyboardInterrupt:
        sys.exit(0)

    else:
        print("\nSuccesfully quantized all models.")


@@ -1,15 +0,0 @@
-#!/usr/bin/env bash
-
-if ! [[ "$1" =~ ^[0-9]{1,2}B$ ]]; then
-    echo
-    echo "Usage: quantize.sh 7B|13B|30B|65B [--remove-f16]"
-    echo
-    exit 1
-fi
-
-for i in `ls models/$1/ggml-model-f16.bin*`; do
-    ./quantize "$i" "${i/f16/q4_0}" 2
-    if [[ "$2" == "--remove-f16" ]]; then
-        rm "$i"
-    fi
-done

utils.cpp

@@ -6,6 +6,7 @@
 #include <regex>
 #include <iostream>
 #include <iterator>
+#include <queue>
 #include <string>
 #include <math.h>
@@ -38,19 +39,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "-p" || arg == "--prompt") {
             params.prompt = argv[++i];
         } else if (arg == "-f" || arg == "--file") {
             std::ifstream file(argv[++i]);
-            std::copy(std::istreambuf_iterator<char>(file),
-                      std::istreambuf_iterator<char>(),
-                      back_inserter(params.prompt));
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
+            if (params.prompt.back() == '\n') {
+                params.prompt.pop_back();
+            }
         } else if (arg == "-n" || arg == "--n_predict") {
             params.n_predict = std::stoi(argv[++i]);
         } else if (arg == "--top_k") {
             params.top_k = std::stoi(argv[++i]);
         } else if (arg == "-c" || arg == "--ctx_size") {
             params.n_ctx = std::stoi(argv[++i]);
+        } else if (arg == "--memory_f16") {
+            params.memory_f16 = true;
         } else if (arg == "--top_p") {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
@@ -65,16 +66,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.model = argv[++i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
-        } else if (arg == "--interactive-start") {
-            params.interactive = true;
-            params.interactive_start = true;
+        } else if (arg == "-ins" || arg == "--instruct") {
+            params.instruct = true;
         } else if (arg == "--color") {
            params.use_color = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
-            params.antiprompt = argv[++i];
+            params.antiprompt.push_back(argv[++i]);
+        } else if (arg == "--ignore-eos") {
+            params.ignore_eos = true;
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
+        } else if (arg == "--random-prompt") {
+            params.random_prompt = true;
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, params);
@@ -85,20 +89,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     return true;
 }

-void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
+void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
     fprintf(stderr, "  -i, --interactive     run in interactive mode\n");
-    fprintf(stderr, "  --interactive-start   run in interactive mode and poll user input at startup\n");
+    fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
-    fprintf(stderr, "                        in interactive mode, poll user input upon seeing PROMPT\n");
+    fprintf(stderr, "                        in interactive mode, poll user input upon seeing PROMPT (can be\n");
+    fprintf(stderr, "                        specified more than once for multiple prompts).\n");
     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
-    fprintf(stderr, "                        prompt to start generation with (default: random)\n");
+    fprintf(stderr, "                        prompt to start generation with (default: empty)\n");
+    fprintf(stderr, "  --random-prompt       start with a randomized prompt.\n");
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
     fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
@@ -107,6 +113,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
     fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
+    fprintf(stderr, "  --memory_f16          use f16 instead of f32 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
@@ -287,58 +295,146 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }

-// TODO: Calculate this constant from the vocabulary
-#define MAX_TOKEN_LEN 18
-// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
-std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    std::vector<gpt_vocab::id> res;
-    std::vector<int> score;
-    std::vector<gpt_vocab::id> prev;
-    int len = text.length();
-
-    score.resize(len + 1);
-    prev.resize(len + 1);
-
-    // Forward pass
-    for (int i = 0; i < len; i++) {
-        int max_len = std::min(len - i, MAX_TOKEN_LEN);
-        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
-            auto sub = text.substr(i, sub_len);
-            auto token = vocab.token_to_id.find(sub);
-            if (token != vocab.token_to_id.end()) {
-                int token_score = sub.length() * sub.length();
-                int local_score = score[i] + token_score;
-                int next = i + sub_len;
-                if (score[next] < local_score) {
-                    score[next] = local_score;
-                    prev[next] = (*token).second;
-                }
-            }
-        }
-    }
-
-    // Backward pass
-    int i = len;
-    while (i > 0) {
-        gpt_vocab::id token_id = prev[i];
-        if (token_id == 0) {
-            // TODO: Return error or something more meaningful
-            printf("failed to tokenize string!\n");
-            break;
-        }
-        res.push_back(token_id);
-        auto token = (*vocab.id_to_token.find(token_id)).second;
-        i -= token.length();
-    }
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
+struct llama_sp_symbol {
+    using index = int;
+    index prev;
+    index next;
+    std::string_view text;
+};
+
+struct llama_sp_bigram {
+    struct comparator {
+        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
+            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+        }
+    };
+    using queue_storage = std::vector<llama_sp_bigram>;
+    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
+    llama_sp_symbol::index left;
+    llama_sp_symbol::index right;
+    float score;
+    size_t size;
+};
+
+struct llama_tokenizer {
+    llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {}
+
+    void tokenize(std::string_view text, std::vector<gpt_vocab::id> & output) {
+        // split string into utf8 chars
+        int index = 0;
+        while (!text.empty()) {
+            llama_sp_symbol sym;
+            size_t char_len = std::min(text.size(), utf8_len(text.data()[0]));
+            sym.text = std::string_view(text.data(), char_len);
+            sym.prev = index - 1;
+            text.remove_prefix(char_len);
+            sym.next = text.empty() ? -1 : index + 1;
+            index++;
+            symbols_.emplace_back(std::move(sym));
+        }
+
+        // seed the work queue with all possible 2-character tokens.
+        for (size_t i = 1; i < symbols_.size(); ++i) {
+            try_add_bigram(i - 1, i);
+        }
+
+        // keep substituting the highest frequency pairs for as long as we can.
+        while (!work_queue_.empty()) {
+            auto bigram = work_queue_.top();
+            work_queue_.pop();
+
+            auto & left_sym = symbols_[bigram.left];
+            auto & right_sym = symbols_[bigram.right];
+
+            // if one of the symbols already got merged, skip it.
+            if (left_sym.text.empty() || right_sym.text.empty() ||
+                left_sym.text.size() + right_sym.text.size() != bigram.size) {
+                continue;
+            }
+
+            // merge the right sym into the left one
+            left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size());
+            right_sym.text = std::string_view("");
+
+            // remove the right sym from the chain
+            left_sym.next = right_sym.next;
+            if (right_sym.next >= 0) {
+                symbols_[right_sym.next].prev = bigram.left;
+            }
+
+            // find more substitutions
+            try_add_bigram(left_sym.prev, bigram.left);
+            try_add_bigram(bigram.left, left_sym.next);
+        }
+
+        for (int i = 0; i != -1; i = symbols_[i].next) {
+            auto& symbol = symbols_[i];
+            auto token = vocab_.token_to_id.find(std::string(symbol.text));
+
+            if (token == vocab_.token_to_id.end()) {
+                // output any symbols that did not form tokens as bytes.
+                for (int j = 0; j < symbol.text.size(); ++j) {
+                    gpt_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
+                    output.push_back(token_id);
+                }
+            } else {
+                output.push_back((*token).second);
+            }
+        }
+    }
+
+private:
+    void try_add_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+
+        std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size());
+        auto token = vocab_.token_to_id.find(std::string(text));
+
+        if (token == vocab_.token_to_id.end()) {
+            return;
+        }
+
+        auto score = vocab_.score.find((*token).second);
+
+        if (score == vocab_.score.end()) {
+            return;
+        }
+
+        llama_sp_bigram bigram;
+        bigram.left = left;
+        bigram.right = right;
+        bigram.score = (*score).second;
+        bigram.size = text.size();
+        work_queue_.push(bigram);
+    }
+
+    const gpt_vocab & vocab_;
+    std::vector<llama_sp_symbol> symbols_;
+    llama_sp_bigram::queue work_queue_;
+};
+
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) {
+    llama_tokenizer tokenizer(vocab);
+    std::vector<gpt_vocab::id> output;
+
+    if (text.size() == 0) {
+        return output;
+    }

     if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
+        output.push_back(1);
     }

-    // Pieces are in reverse order so correct that
-    std::reverse(res.begin(), res.end());
-
-    return res;
+    tokenizer.tokenize(text, output);
+    return output;
 }

 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
@@ -398,7 +494,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
                 logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
             } else {
                 logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
             }
         } else {
             logits_id.push_back(std::make_pair(logits[i]*scale, i));
         }
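
The rewritten tokenizer drops the old length-squared dynamic programming search and instead follows the SentencePiece approach: split the input into UTF-8 characters, repeatedly merge the adjacent pair whose combined piece has the highest vocabulary score, and emit byte tokens (byte value + 3) for anything left over. The sketch below restates that loop in Python with a heap standing in for std::priority_queue; it is an illustration written against assumed token_to_id and score dictionaries, not the shipped implementation, and it starts from characters rather than raw UTF-8 bytes for brevity.

import heapq

def sp_tokenize(text, token_to_id, score):
    # start from single characters (stand-in for the UTF-8 split in utils.cpp)
    symbols = list(text)
    if not symbols:
        return []

    prev = [-1] + list(range(len(symbols) - 1))
    nxt = list(range(1, len(symbols))) + [-1]
    heap = []

    def try_add_bigram(i, j):
        if i == -1 or j == -1:
            return
        piece = symbols[i] + symbols[j]
        tid = token_to_id.get(piece)
        if tid is None or tid not in score:
            return
        # heapq is a min-heap, so negate the score to pop the best-scoring pair first
        heapq.heappush(heap, (-score[tid], i, j, piece))

    for i in range(1, len(symbols)):
        try_add_bigram(i - 1, i)

    while heap:
        _, i, j, piece = heapq.heappop(heap)
        # skip stale entries whose symbols were already merged elsewhere
        if not symbols[i] or not symbols[j] or len(symbols[i]) + len(symbols[j]) != len(piece):
            continue
        symbols[i] = piece      # merge the right symbol into the left one
        symbols[j] = ""
        nxt[i] = nxt[j]         # unlink the right symbol from the chain
        if nxt[j] != -1:
            prev[nxt[j]] = i
        try_add_bigram(prev[i], i)
        try_add_bigram(i, nxt[i])

    out = []
    i = 0
    while i != -1:
        tid = token_to_id.get(symbols[i])
        if tid is None:
            # byte fallback, mirroring "token id = byte value + 3"
            out.extend(3 + b for b in symbols[i].encode("utf-8"))
        else:
            out.append(tid)
        i = nxt[i]
    return out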

utils.h

@@ -18,6 +18,7 @@ struct gpt_params {
     int32_t n_predict = 128; // new tokens to predict
     int32_t repeat_last_n = 64; // last n tokens to penalize
     int32_t n_ctx = 512; //context size
+    bool memory_f16 = false; // use f16 instead of f32 for memory kv

     // sampling parameters
     int32_t top_k = 40;
@@ -27,14 +28,18 @@ struct gpt_params {
     int32_t n_batch = 8; // batch size for prompt processing

     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
-    std::string prompt;
+    std::string prompt = "";
+
+    bool random_prompt = false;

     bool use_color = false; // use color to distinguish generations and inputs

     bool interactive = false; // interactive mode
     bool interactive_start = false; // reverse prompt immediately
-    std::string antiprompt = ""; // string upon seeing which more user input is prompted
+    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+
+    bool instruct = false; // instruction mode (used for Alpaca models)
+    bool ignore_eos = false; // do not stop generating after eos
 };

 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@@ -53,6 +58,7 @@ struct gpt_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
+    std::map<id, float> score;
 };

 void replace(std::string & str, const std::string & needle, const std::string & replacement);
@@ -74,7 +80,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri

 // TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
 // ref: https://github.com/google/sentencepiece
-std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos);
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos);

 // load the tokens from encoder.json
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);