Backwards compatibility formats all done
Merge branch 'master' into concedo # Conflicts: # CMakeLists.txt # README.md # llama.cpp
This commit is contained in:
commit
559a1967f7
21 changed files with 832 additions and 494 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -22,6 +22,7 @@ models/*
|
|||
/result
|
||||
/perplexity
|
||||
/embedding
|
||||
/Pipfile
|
||||
|
||||
arm_neon.h
|
||||
compile_commands.json
|
||||
|
|
5
Makefile
5
Makefile
|
@ -76,7 +76,10 @@ endif
|
|||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -mf16c
|
||||
F16C_M := $(shell sysctl machdep.cpu.features)
|
||||
ifneq (,$(findstring F16C,$(F16C_M)))
|
||||
CFLAGS += -mf16c
|
||||
endif
|
||||
AVX1_M := $(shell sysctl machdep.cpu.features)
|
||||
ifneq (,$(findstring FMA,$(AVX1_M)))
|
||||
CFLAGS += -mfma
|
||||
|
|
|
@ -21,12 +21,11 @@ What does it mean? You get llama.cpp with a fancy UI, persistent stories, editin
|
|||
|
||||
## Considerations
|
||||
- Don't want to use pybind11 due to dependencies on MSVCC
|
||||
- ZERO or MINIMAL changes as possible to main.cpp - do not move their function declarations elsewhere!
|
||||
- Leave main.cpp UNTOUCHED, We want to be able to update the repo and pull any changes automatically.
|
||||
- ZERO or MINIMAL changes as possible to parent repo files - do not move their function declarations elsewhere! We want to be able to update the repo and pull any changes automatically.
|
||||
- No dynamic memory allocation! Setup structs with FIXED (known) shapes and sizes for ALL output fields. Python will ALWAYS provide the memory, we just write to it.
|
||||
- No external libraries or dependencies. That means no Flask, Pybind and whatever. All You Need Is Python.
|
||||
- Since v1.0.6, requires libopenblas, the prebuilt windows binaries are included in this repo. If not found, it will fall back to a mode without BLAS. If you want you can also link your own install of OpenBLAS manually with `LLAMA_OPENBLAS=1`
|
||||
- I plan to keep backwards compatibility with all past ggml llama.cpp AND alpaca.cpp models.
|
||||
- **I plan to keep backwards compatibility with ALL past llama.cpp AND alpaca.cpp models**. But you are also encouraged to reconvert/update your models if possible for best results.
|
||||
|
||||
## License
|
||||
- The original GGML library and llama.cpp by ggerganov are licensed under the MIT License
|
||||
|
@ -34,5 +33,5 @@ What does it mean? You get llama.cpp with a fancy UI, persistent stories, editin
|
|||
- The provided python ctypes bindings in llamacpp.dll are also under the AGPL v3.0 License
|
||||
|
||||
## Notes
|
||||
- There is a fundamental flaw with llama.cpp, which causes generation delay to scale linearly with original prompt length. If you care, **please contribute to [this discussion](https://github.com/ggerganov/llama.cpp/discussions/229)** which, if resolved, will actually make this viable.
|
||||
- Generation delay scales linearly with original prompt length. See [this discussion](https://github.com/ggerganov/llama.cpp/discussions/229). If OpenBLAS is enabled then prompt ingestion becomes about 2-3x faster. This is automatic on windows, but will require linking on OSX and Linux.
|
||||
- I have heard of someone claiming a false AV positive report. The exe is a simple pyinstaller bundle that includes the necessary python scripts and dlls to run. If this still concerns you, you might wish to rebuild everything from source code using the makefile, and you can rebuild the exe yourself with pyinstaller by using `make_pyinstaller.bat`
|
|
@ -84,6 +84,11 @@ def read_variables(fin):
|
|||
shape = shape[::-1]
|
||||
name = fin.read(name_length).decode("utf-8")
|
||||
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fin.tell()
|
||||
tensor_data_offset = (tensor_data_offset + 31) & -32
|
||||
fin.seek(tensor_data_offset)
|
||||
|
||||
if ftype_cur == 2:
|
||||
# 4-bit quantized weights
|
||||
dtype = np.uint8
|
||||
|
|
|
@ -72,6 +72,11 @@ def write_header(shape, dst_name, ftype_cur):
|
|||
fout.write(struct.pack("i" * len(shape), *shape[::-1]))
|
||||
fout.write(sname)
|
||||
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fout.tell()
|
||||
tensor_data_offset = (tensor_data_offset + 31) & -32
|
||||
fout.seek(tensor_data_offset)
|
||||
|
||||
def convert_non_q4(src_name, dst_name):
|
||||
v = model[src_name]
|
||||
shape = v.shape
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Convert a LLaMA model checkpoint to a ggml compatible file
|
||||
# Convert a LLaMA model checkpoint to a ggjt compatible file
|
||||
#
|
||||
# Load the model using Torch
|
||||
# Iterate over all variables and write them to a binary file.
|
||||
|
@ -24,8 +24,57 @@ import torch
|
|||
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
def parse_args():
|
||||
QK = 32
|
||||
|
||||
GGML_TYPE_Q4_0 = 0
|
||||
GGML_TYPE_Q4_1 = 1
|
||||
GGML_TYPE_I8 = 2
|
||||
GGML_TYPE_I16 = 3
|
||||
GGML_TYPE_I32 = 4
|
||||
GGML_TYPE_F16 = 5
|
||||
GGML_TYPE_F32 = 6
|
||||
|
||||
WTYPES = {
|
||||
0: GGML_TYPE_F32,
|
||||
1: GGML_TYPE_F16,
|
||||
2: GGML_TYPE_Q4_0,
|
||||
3: GGML_TYPE_Q4_1,
|
||||
}
|
||||
|
||||
GGML_BLCK_SIZE = {
|
||||
GGML_TYPE_Q4_0: QK,
|
||||
GGML_TYPE_Q4_1: QK,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 1,
|
||||
GGML_TYPE_I32: 1,
|
||||
GGML_TYPE_F16: 1,
|
||||
GGML_TYPE_F32: 1,
|
||||
}
|
||||
|
||||
GGML_TYPE_SIZE = {
|
||||
GGML_TYPE_Q4_0: 4 + QK//2,
|
||||
GGML_TYPE_Q4_1: 4*2 + QK//2,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 2,
|
||||
GGML_TYPE_I32: 4,
|
||||
GGML_TYPE_F16: 2,
|
||||
GGML_TYPE_F32: 4,
|
||||
}
|
||||
|
||||
def ggml_nelements(shape):
|
||||
r = 1
|
||||
for i in shape:
|
||||
r *= i
|
||||
return r
|
||||
|
||||
def ggml_nbytes(shape, ftype):
|
||||
x = ggml_nelements(shape)
|
||||
t = WTYPES[ftype]
|
||||
x *= GGML_TYPE_SIZE[t]
|
||||
x //= GGML_BLCK_SIZE[t]
|
||||
return x
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
|
||||
parser.add_argument('dir_model', help='directory containing the model checkpoint')
|
||||
parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
|
||||
|
@ -33,7 +82,6 @@ def parse_args():
|
|||
return parser.parse_args()
|
||||
|
||||
def get_n_parts(dim):
|
||||
|
||||
mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
|
||||
n_parts = mappings.get(dim)
|
||||
if n_parts is None:
|
||||
|
@ -44,30 +92,24 @@ def get_n_parts(dim):
|
|||
return n_parts
|
||||
|
||||
def load_hparams_and_tokenizer(dir_model):
|
||||
|
||||
# `dir_model` is something like `models/7B` or `models/7B/`.
|
||||
# "tokenizer.model" is expected under model's parent dir.
|
||||
# When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
|
||||
# Let's use the model's parent dir directly.
|
||||
model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
|
||||
|
||||
fname_hparams = f"{dir_model}/params.json"
|
||||
fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
|
||||
|
||||
with open(fname_hparams, "r") as f:
|
||||
hparams = json.load(f)
|
||||
print(hparams)
|
||||
|
||||
tokenizer = SentencePieceProcessor(fname_tokenizer)
|
||||
hparams.update({"vocab_size": tokenizer.vocab_size()})
|
||||
|
||||
return hparams, tokenizer
|
||||
|
||||
def write_header(fout, hparams, ftype):
|
||||
|
||||
keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
|
||||
values = [
|
||||
0x67676d66, # magic: ggmf in hex
|
||||
0x67676a74, # magic: ggjt in hex
|
||||
1, # file version
|
||||
*[hparams[key] for key in keys],
|
||||
hparams["dim"] // hparams["n_heads"], # rot (obsolete)
|
||||
|
@ -76,7 +118,6 @@ def write_header(fout, hparams, ftype):
|
|||
fout.write(struct.pack("i" * len(values), *values))
|
||||
|
||||
def write_tokens(fout, tokenizer):
|
||||
|
||||
for i in range(tokenizer.vocab_size()):
|
||||
if tokenizer.is_unknown(i):
|
||||
text = " \u2047 ".encode("utf-8")
|
||||
|
@ -95,85 +136,139 @@ def write_tokens(fout, tokenizer):
|
|||
fout.write(text)
|
||||
fout.write(struct.pack("f", tokenizer.get_score(i)))
|
||||
|
||||
def process_and_write_variables(fout, model, ftype):
|
||||
|
||||
def process_and_write_variables(fout, model, ftype, part_id, n_parts):
|
||||
for name, datao in model.items():
|
||||
|
||||
if name.endswith("freqs"):
|
||||
continue
|
||||
|
||||
shape = datao.shape
|
||||
|
||||
print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
|
||||
|
||||
# remove dimensions with a single element
|
||||
data = datao.numpy().squeeze()
|
||||
n_dims = len(shape)
|
||||
partshape = data.shape
|
||||
n_dims = len(data.shape)
|
||||
assert n_dims in (1, 2)
|
||||
|
||||
# default type is fp16
|
||||
print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}")
|
||||
|
||||
# coerce single-dimensional tensors from float16 to float32
|
||||
ftype_cur = 1
|
||||
if ftype == 0 or n_dims == 1:
|
||||
print(" Converting to float32")
|
||||
data = data.astype(np.float32)
|
||||
ftype_cur = 0
|
||||
blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
|
||||
type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
|
||||
|
||||
# header
|
||||
# determine dimension along which multipart tensor is sharded
|
||||
#
|
||||
# split_dim 0 regex:
|
||||
# - output.*
|
||||
# - layers.*.attention.wq.weight
|
||||
# - layers.*.attention.wk.weight
|
||||
# - layers.*.attention.wv.weight
|
||||
# - layers.*.feed_forward.w1.weight
|
||||
# - layers.*.feed_forward.w3.weight
|
||||
#
|
||||
# split_dim 1 regex:
|
||||
# - tok_embeddings.*
|
||||
# - layers.*.attention.wo.weight
|
||||
# - layers.*.feed_forward.w2.weight
|
||||
#
|
||||
if n_dims > 1:
|
||||
split_dim = 1
|
||||
if "tok_embeddings" in name:
|
||||
split_dim = 1
|
||||
elif "layers" in name:
|
||||
if "attention.wo.weight" in name:
|
||||
split_dim = 1
|
||||
elif "feed_forward.w2.weight" in name:
|
||||
split_dim = 1
|
||||
else:
|
||||
split_dim = 0
|
||||
elif "output" in name:
|
||||
split_dim = 0
|
||||
|
||||
# output tensor header
|
||||
fullshape = list(partshape)
|
||||
if n_dims > 1:
|
||||
fullshape[split_dim] *= n_parts
|
||||
sname = name.encode('utf-8')
|
||||
fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
|
||||
for dim in reversed(data.shape):
|
||||
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
|
||||
for dim in reversed(fullshape):
|
||||
fout.write(struct.pack("i", dim))
|
||||
fout.write(sname)
|
||||
|
||||
# data output to file
|
||||
data.tofile(fout)
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fout.tell()
|
||||
while tensor_data_offset % QK != 0:
|
||||
fout.write(struct.pack("B", 0))
|
||||
tensor_data_offset += 1
|
||||
|
||||
# output unified mappable tensor data
|
||||
if n_dims == 1 or n_parts == 1:
|
||||
# copy tensor which we thankfully received in one piece
|
||||
if part_id == 0:
|
||||
data.tofile(fout)
|
||||
elif split_dim == 0:
|
||||
# reassemble multifile tensor containing some of the rows
|
||||
rows_per_chunk = partshape[0]
|
||||
current_row = part_id * rows_per_chunk
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset = current_row * bytes_per_row
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
data.tofile(fout)
|
||||
elif split_dim == 1:
|
||||
# reassemble multifile tensor containing some of the cols
|
||||
cols_per_chunk = partshape[1]
|
||||
current_col = part_id * cols_per_chunk
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset_current_col = current_col // blck_size * type_size
|
||||
for row in range(partshape[0]):
|
||||
offset_row = row * bytes_per_row
|
||||
offset = offset_row + offset_current_col
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
data[row].tofile(fout)
|
||||
|
||||
# advance file position to next tensor
|
||||
fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
|
||||
|
||||
def main():
|
||||
|
||||
args = parse_args()
|
||||
dir_model = args.dir_model
|
||||
ftype = args.ftype
|
||||
ftype_str = ["f32", "f16"]
|
||||
|
||||
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
||||
|
||||
print(args)
|
||||
|
||||
# if only writing vocab to file
|
||||
if args.vocab_only:
|
||||
|
||||
fname_model = f"{dir_model}/consolidated.00.pth"
|
||||
fname_out = f"{dir_model}/ggml-vocab.bin"
|
||||
|
||||
print(f"Extracting only the vocab from '{fname_model}'\n")
|
||||
|
||||
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
|
||||
|
||||
print(f"Done. Output file: {fname_out}\n")
|
||||
|
||||
return
|
||||
|
||||
n_parts = get_n_parts(hparams["dim"])
|
||||
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
|
||||
|
||||
for p in range(n_parts):
|
||||
# we output a single file for ggml
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
offset_of_tensors = fout.tell()
|
||||
# the tensors we load could be split across multiple files
|
||||
for part_id in range(n_parts):
|
||||
fout.seek(offset_of_tensors)
|
||||
print(f"Processing part {part_id+1} of {n_parts}\n")
|
||||
fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
process_and_write_variables(fout, model, ftype, part_id, n_parts)
|
||||
del model
|
||||
|
||||
print(f"Processing part {p+1} of {n_parts}\n")
|
||||
|
||||
fname_model = f"{dir_model}/consolidated.0{p}.pth"
|
||||
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
|
||||
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
process_and_write_variables(fout, model, ftype)
|
||||
|
||||
del model
|
||||
|
||||
print(f"Done. Output file: {fname_out}, (part {p})\n")
|
||||
print(f"Done. Output file: {fname_out}\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -19,7 +19,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// needed to initialize f16 tables
|
||||
{
|
||||
struct ggml_init_params params = { 0, NULL };
|
||||
struct ggml_init_params params = { 0, NULL, false };
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
|
29
expose.cpp
29
expose.cpp
|
@ -68,7 +68,8 @@ extern "C" {
|
|||
char text[16384]; //16kb should be enough for any response
|
||||
};
|
||||
|
||||
bool legacy_format = false;
|
||||
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
|
||||
int file_format = 0;
|
||||
llama_context_params ctx_params;
|
||||
gpt_params params;
|
||||
int n_past = 0;
|
||||
|
@ -95,22 +96,28 @@ extern "C" {
|
|||
ctx_params.f16_kv = inputs.f16_kv;
|
||||
ctx_params.logits_all = false;
|
||||
|
||||
ctx = llama_init_from_file(model.c_str(), ctx_params);
|
||||
|
||||
file_format = check_file_format(model.c_str());
|
||||
printf("\nFile format detected: (ver %d)\n",file_format);
|
||||
|
||||
if(file_format==1 || file_format==2)
|
||||
{
|
||||
ctx = legacy_llama_init_from_file(model.c_str(), ctx_params);
|
||||
}
|
||||
else
|
||||
{
|
||||
ctx = llama_init_from_file(model.c_str(), ctx_params);
|
||||
}
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
//return val: 0=fail, 1=newformat, 2=legacy
|
||||
int fileformat = check_file_format(model.c_str());
|
||||
|
||||
legacy_format = (fileformat==1?true:false);
|
||||
if(legacy_format)
|
||||
if(file_format<3)
|
||||
{
|
||||
printf("\n---\nWarning: Your model is using an OUTDATED format. Please reconvert it for better results!\n");
|
||||
printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n",file_format);
|
||||
}
|
||||
|
||||
|
||||
//determine mem per token
|
||||
const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
|
||||
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
|
||||
|
@ -149,7 +156,7 @@ extern "C" {
|
|||
|
||||
// tokenize the prompt
|
||||
std::vector<llama_token> embd_inp;
|
||||
if(legacy_format)
|
||||
if(file_format==1)
|
||||
{
|
||||
embd_inp = ::legacy_llama_tokenize(ctx, params.prompt, true);
|
||||
}else{
|
||||
|
|
84
extra.cpp
84
extra.cpp
|
@ -18,7 +18,7 @@
|
|||
#include <alloca.h>
|
||||
#endif
|
||||
|
||||
//return val: 0=fail, 1=legacy, 2=newformat
|
||||
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
|
||||
int check_file_format(const std::string & fname)
|
||||
{
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
|
@ -33,16 +33,94 @@
|
|||
int fileformat = 0;
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
|
||||
if (magic == 0x67676d6c) { //v1 format ggml, alpaca
|
||||
fileformat = 1;
|
||||
}else{
|
||||
}
|
||||
else if(magic == 0x67676d66) //v2 format ggmf
|
||||
{
|
||||
fileformat = 2;
|
||||
}
|
||||
else if(magic == 0x67676a74) //v3 format ggjt
|
||||
{
|
||||
fileformat = 3; //ggjt by default
|
||||
}
|
||||
fin.close();
|
||||
|
||||
return fileformat;
|
||||
}
|
||||
|
||||
//freeze all the configurations for model loading for v1 and v2 formats
|
||||
struct llama_context * legacy_llama_init_from_file(const char * path_model, struct llama_context_params params)
|
||||
{
|
||||
ggml_time_init();
|
||||
|
||||
llama_context * ctx = new llama_context;
|
||||
|
||||
if (params.seed <= 0) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
ctx->rng = std::mt19937(params.seed);
|
||||
ctx->logits_all = params.logits_all;
|
||||
|
||||
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
if (!legacy_llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
|
||||
params.vocab_only, params.progress_callback,
|
||||
params.progress_callback_user_data)) {
|
||||
fprintf(stderr, "%s: failed to load model\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (params.use_mlock) {
|
||||
char *err;
|
||||
if (!ggml_mlock(ctx->model.ctx,
|
||||
ctx->model.mm_addr,
|
||||
ctx->model.mm_length,
|
||||
&err)) {
|
||||
fprintf(stderr, "%s\n", err);
|
||||
free(err);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// reserve memory for context buffers
|
||||
{
|
||||
if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
|
||||
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
{
|
||||
const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
|
||||
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
// resized during inference
|
||||
if (params.logits_all) {
|
||||
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
||||
} else {
|
||||
ctx->logits.reserve(hparams.n_ctx);
|
||||
}
|
||||
|
||||
if (params.embedding){
|
||||
ctx->embedding.resize(hparams.n_embd);
|
||||
}
|
||||
|
||||
ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
|
||||
|
||||
ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
|
||||
ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
//legacy llama model format v1 and v2 loader. there is a lot of duplicate code,
|
||||
//but it may be better to freeze it as such rather than risk tiny breaking changes
|
||||
static bool legacy_llama_model_load(
|
||||
|
|
3
extra.h
3
extra.h
|
@ -18,4 +18,5 @@
|
|||
int check_file_format(const std::string & fname);
|
||||
|
||||
std::vector<llama_token> legacy_llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
|
||||
static bool legacy_llama_model_load(const std::string & fname, llama_context & lctx, int n_ctx, int n_parts, ggml_type memory_type, bool vocab_only, llama_progress_callback progress_callback, void *progress_callback_user_data);
|
||||
static bool legacy_llama_model_load(const std::string & fname, llama_context & lctx, int n_ctx, int n_parts, ggml_type memory_type, bool vocab_only, llama_progress_callback progress_callback, void *progress_callback_user_data);
|
||||
struct llama_context * legacy_llama_init_from_file(const char * path_model, struct llama_context_params params);
|
55
ggml.c
55
ggml.c
|
@ -1038,8 +1038,8 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|||
const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
|
||||
|
||||
// convert to 2x uint16x8_t
|
||||
const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
|
||||
const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
|
||||
const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
|
||||
const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
|
||||
|
||||
// convert to 4x float32x4_t
|
||||
const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
|
||||
|
@ -1297,7 +1297,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
|
|||
_mm256_storeu_ps(arr, y);
|
||||
|
||||
for (int i = 0; i < 8; i++)
|
||||
x[i] = GGML_FP16_TO_FP32(arr[i]);
|
||||
x[i] = GGML_FP32_TO_FP16(arr[i]);
|
||||
}
|
||||
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
|
||||
#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
|
||||
|
@ -1829,7 +1829,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|||
|
||||
const int superblock_size = 8;
|
||||
const int superblock_count = nb / superblock_size;
|
||||
const int remainder = nb % superblock_size;
|
||||
|
||||
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
|
||||
int i = superblock_ix * superblock_size;
|
||||
|
@ -2530,8 +2529,9 @@ struct ggml_context {
|
|||
void * mem_buffer;
|
||||
bool mem_buffer_owned;
|
||||
bool mem_buffer_mlocked;
|
||||
bool no_alloc;
|
||||
|
||||
int n_objects;
|
||||
int n_objects;
|
||||
|
||||
struct ggml_object * objects_begin;
|
||||
struct ggml_object * objects_end;
|
||||
|
@ -2816,6 +2816,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
|
||||
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
|
||||
/*.mem_buffer_mlocked =*/ false,
|
||||
/*.no_alloc =*/ params.no_alloc,
|
||||
/*.n_objects =*/ 0,
|
||||
/*.objects_begin =*/ NULL,
|
||||
/*.objects_end =*/ NULL,
|
||||
|
@ -2883,36 +2884,47 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
|
|||
return result;
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
||||
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
|
||||
#else
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
|
||||
#endif
|
||||
|
||||
bool ggml_mlock_supported(void) {
|
||||
return GGML_MLOCK_SUPPORT;
|
||||
}
|
||||
|
||||
bool ggml_mlock(
|
||||
struct ggml_context * ctx,
|
||||
const void *opt_extra_addr,
|
||||
size_t opt_extra_len,
|
||||
char **err_p) {
|
||||
// TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
|
||||
#if GGML_MLOCK_SUPPORT
|
||||
#ifdef __APPLE__
|
||||
#define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
|
||||
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l)."
|
||||
#else
|
||||
#define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
|
||||
#endif
|
||||
bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
|
||||
if (ctx->mem_buffer_mlocked) {
|
||||
return true;
|
||||
}
|
||||
if (mlock(ctx->mem_buffer, ctx->mem_size)) {
|
||||
int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
|
||||
ctx->mem_size, strerror(errno));
|
||||
GGML_ASSERT(ret >= 0);
|
||||
if (mlock(ctx->mem_buffer, ctx->mem_size) ||
|
||||
(opt_extra_len &&
|
||||
mlock(opt_extra_addr, opt_extra_len))) {
|
||||
if ((*err_p = malloc(1024))) {
|
||||
snprintf(*err_p, 1024,
|
||||
"failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
|
||||
ctx->mem_size + opt_extra_len,
|
||||
strerror(errno));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
ctx->mem_buffer_mlocked = true;
|
||||
return true;
|
||||
}
|
||||
#else // GGML_MLOCK_SUPPORT
|
||||
bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
|
||||
*err_p = strdup("can't mlock because it's not supported on this system");
|
||||
return false;
|
||||
}
|
||||
#endif // GGML_MLOCK_SUPPORT
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
@ -2931,7 +2943,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
|||
|
||||
size_t size_needed = 0;
|
||||
|
||||
if (data == NULL) {
|
||||
if (data == NULL && !ctx->no_alloc) {
|
||||
size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
|
||||
for (int i = 1; i < n_dims; i++) {
|
||||
size_needed *= ne[i];
|
||||
|
@ -3015,7 +3027,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
|||
/*.perf_runs =*/ 0,
|
||||
/*.perf_cycles =*/ 0,
|
||||
/*.perf_time_us =*/ 0,
|
||||
/*.data =*/ data == NULL ? (void *)(result + 1) : data,
|
||||
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
|
||||
/*.pad =*/ { 0 },
|
||||
};
|
||||
|
||||
|
@ -10278,6 +10290,7 @@ enum ggml_opt_result ggml_opt(
|
|||
struct ggml_init_params params_ctx = {
|
||||
.mem_size = 16*1024*1024,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false,
|
||||
};
|
||||
|
||||
ctx = ggml_init(params_ctx);
|
||||
|
|
7
ggml.h
7
ggml.h
|
@ -316,6 +316,7 @@ struct ggml_init_params {
|
|||
// memory pool
|
||||
size_t mem_size; // bytes
|
||||
void * mem_buffer; // if NULL, memory will be allocated internally
|
||||
bool no_alloc; // don't allocate memory for the tensor data
|
||||
};
|
||||
|
||||
void ggml_time_init(void); // call this once at the beginning of the program
|
||||
|
@ -344,7 +345,11 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
|
|||
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
|
||||
|
||||
bool ggml_mlock_supported(void);
|
||||
bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
|
||||
bool ggml_mlock(
|
||||
struct ggml_context * ctx,
|
||||
const void *opt_extra_addr,
|
||||
size_t opt_extra_len,
|
||||
char **err_p);
|
||||
|
||||
struct ggml_tensor * ggml_new_tensor(
|
||||
struct ggml_context * ctx,
|
||||
|
|
482
llama.cpp
482
llama.cpp
|
@ -12,6 +12,19 @@
|
|||
#include <cassert>
|
||||
#include <cstring>
|
||||
|
||||
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#include <Windows.h>
|
||||
#else
|
||||
#include <sys/types.h>
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
|
||||
#define Min(X, Y) ((Y) > (X) ? (X) : (Y))
|
||||
#define Max(X, Y) ((Y) < (X) ? (X) : (Y))
|
||||
|
||||
#define LLAMA_USE_SCRATCH
|
||||
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
||||
|
||||
|
@ -142,6 +155,10 @@ struct llama_model {
|
|||
// the model memory buffer
|
||||
std::vector<uint8_t> buf;
|
||||
|
||||
// model memory mapped file
|
||||
void * mm_addr = NULL;
|
||||
uint64_t mm_length = 0;
|
||||
|
||||
// tensors
|
||||
int n_loaded;
|
||||
std::unordered_map<std::string, struct ggml_tensor *> tensors;
|
||||
|
@ -165,6 +182,7 @@ struct llama_context {
|
|||
|
||||
int64_t t_load_us = 0;
|
||||
int64_t t_start_us = 0;
|
||||
bool has_evaluated_once = false;
|
||||
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_eval_us = 0;
|
||||
|
@ -206,7 +224,7 @@ struct llama_context {
|
|||
}
|
||||
|
||||
if (buf_last >= 0) {
|
||||
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
||||
buf_max_size[buf_last] = Max(buf_max_size[buf_last], last_size);
|
||||
}
|
||||
|
||||
buf_last = i;
|
||||
|
@ -246,6 +264,7 @@ static bool kv_cache_init(
|
|||
struct ggml_init_params params;
|
||||
params.mem_size = cache.buf.size();
|
||||
params.mem_buffer = cache.buf.data();
|
||||
params.no_alloc = false;
|
||||
|
||||
cache.ctx = ggml_init(params);
|
||||
|
||||
|
@ -288,6 +307,58 @@ struct llama_context_params llama_context_default_params() {
|
|||
// model loading
|
||||
//
|
||||
|
||||
static void *mmap_file(const char *fname, uint64_t *mm_length) {
|
||||
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
|
||||
HANDLE hFile = CreateFileA(fname,
|
||||
GENERIC_READ,
|
||||
FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
|
||||
NULL,
|
||||
OPEN_EXISTING,
|
||||
FILE_ATTRIBUTE_NORMAL | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED,
|
||||
NULL);
|
||||
if (hFile == INVALID_HANDLE_VALUE) return 0;
|
||||
LARGE_INTEGER fileSize;
|
||||
fileSize.QuadPart = -1;
|
||||
GetFileSizeEx(hFile, &fileSize);
|
||||
int64_t length = fileSize.QuadPart;
|
||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
CloseHandle(hFile);
|
||||
if (!hMapping) return 0;
|
||||
void *addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
CloseHandle(hMapping);
|
||||
if (!addr) return 0;
|
||||
#else
|
||||
int fd = open(fname, O_RDONLY);
|
||||
if (fd == -1) return 0;
|
||||
int64_t length = lseek(fd, 0, SEEK_END);
|
||||
void *addr = mmap(NULL, length, PROT_READ, MAP_SHARED, fd, 0);
|
||||
close(fd);
|
||||
if (addr == MAP_FAILED) return 0;
|
||||
#endif
|
||||
*mm_length = length;
|
||||
return addr;
|
||||
}
|
||||
|
||||
static void munmap_file(void * addr, size_t length) {
|
||||
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
|
||||
UnmapViewOfFile(addr);
|
||||
#else
|
||||
munmap(addr, length);
|
||||
#endif
|
||||
}
|
||||
|
||||
static bool report_bad_magic(const char *path, uint32_t got, uint32_t want) {
|
||||
fprintf(stderr,
|
||||
"%s: invalid model file (bad magic [got %#x want %#x])\n"
|
||||
"\tyou most likely need to regenerate your ggml files\n"
|
||||
"\tthe benefit is you'll get 10-100x faster load times\n"
|
||||
"\tsee https://github.com/ggerganov/llama.cpp/issues/91\n"
|
||||
"\tuse convert-pth-to-ggml.py to regenerate from original pth\n"
|
||||
"\tuse migrate-ggml-2023-03-30-pr613.py if you deleted originals\n",
|
||||
path, got, want);
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool llama_model_load(
|
||||
const std::string & fname,
|
||||
llama_context & lctx,
|
||||
|
@ -299,23 +370,24 @@ static bool llama_model_load(
|
|||
void *progress_callback_user_data) {
|
||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
lctx.t_start_us = t_start_us;
|
||||
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
lctx.t_start_us = ggml_time_us();
|
||||
|
||||
auto & model = lctx.model;
|
||||
auto & vocab = lctx.vocab;
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
||||
if (!fin) {
|
||||
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool legacy_file_format = false;
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
||||
|
||||
fin.seekg(0, fin.end);
|
||||
const size_t file_size = fin.tellg();
|
||||
fin.seekg(0);
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
|
@ -323,14 +395,11 @@ static bool llama_model_load(
|
|||
if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files or convert them with convert-unversioned-ggml-to-ggml.py!)\n",
|
||||
__func__, fname.c_str());
|
||||
legacy_file_format = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (magic != LLAMA_FILE_MAGIC) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
if (magic != LLAMA_FILE_MAGIC) {
|
||||
return report_bad_magic(fname.c_str(), magic, LLAMA_FILE_MAGIC);
|
||||
}
|
||||
|
||||
uint32_t format_version;
|
||||
fin.read((char *) &format_version, sizeof(format_version));
|
||||
|
@ -340,7 +409,6 @@ static bool llama_model_load(
|
|||
__func__, fname.c_str(), format_version, LLAMA_FILE_VERSION);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int n_ff = 0;
|
||||
|
@ -421,10 +489,7 @@ static bool llama_model_load(
|
|||
}
|
||||
|
||||
float score;
|
||||
if(!legacy_file_format)
|
||||
{
|
||||
fin.read((char *) &score, sizeof(score));
|
||||
}
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
|
||||
|
@ -456,43 +521,24 @@ static bool llama_model_load(
|
|||
}
|
||||
}
|
||||
|
||||
// map model into memory
|
||||
char *mm_addr = NULL;
|
||||
model.mm_addr = mmap_file(fname.c_str(), &model.mm_length);
|
||||
if (model.mm_addr == NULL) {
|
||||
fprintf(stderr, "%s: failed to mmap '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
mm_addr = (char *)model.mm_addr;
|
||||
fprintf(stderr, "%s: ggml map size = %6.2f MB\n", __func__, model.mm_length/(1024.0*1024.0));
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const auto &hparams = model.hparams;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
|
||||
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
|
||||
|
||||
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
|
||||
|
||||
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
|
||||
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
|
||||
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
|
||||
|
||||
ctx_size += (5 + 10*n_layer)*256; // object overhead
|
||||
|
||||
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
||||
fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
|
||||
}
|
||||
|
||||
// print memory requirements
|
||||
|
@ -502,6 +548,7 @@ static bool llama_model_load(
|
|||
// this is the total memory required to run the inference
|
||||
const size_t mem_required =
|
||||
ctx_size +
|
||||
model.mm_length +
|
||||
MEM_REQ_SCRATCH0.at(model.type) +
|
||||
MEM_REQ_SCRATCH1.at(model.type) +
|
||||
MEM_REQ_EVAL.at (model.type);
|
||||
|
@ -521,6 +568,7 @@ static bool llama_model_load(
|
|||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ lctx.model.buf.size(),
|
||||
/*.mem_buffer =*/ lctx.model.buf.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
|
@ -583,234 +631,106 @@ static bool llama_model_load(
|
|||
}
|
||||
}
|
||||
|
||||
const size_t file_offset = fin.tellg();
|
||||
|
||||
fin.close();
|
||||
|
||||
std::vector<uint8_t> tmp;
|
||||
|
||||
if (progress_callback) {
|
||||
progress_callback(0.0, progress_callback_user_data);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_parts; ++i) {
|
||||
const int part_id = i;
|
||||
//const int part_id = n_parts - i - 1;
|
||||
fprintf(stderr, "%s: loading tensors from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
std::string fname_part = fname;
|
||||
if (i > 0) {
|
||||
fname_part += "." + std::to_string(i);
|
||||
}
|
||||
// load weights
|
||||
{
|
||||
size_t total_size = 0;
|
||||
model.n_loaded = 0;
|
||||
|
||||
fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
fin = std::ifstream(fname_part, std::ios::binary);
|
||||
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
|
||||
fin.seekg(0, fin.end);
|
||||
const size_t file_size = fin.tellg();
|
||||
|
||||
fin.seekg(file_offset);
|
||||
|
||||
// load weights
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
model.n_loaded = 0;
|
||||
|
||||
fprintf(stderr, "%s: ", __func__);
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
// split_type = 0: split by columns
|
||||
// split_type = 1: split by rows
|
||||
int split_type = 0;
|
||||
|
||||
// split_type = 0:
|
||||
// regex:
|
||||
// - tok_embeddings.*
|
||||
// - layers.*.attention.wo.weight
|
||||
// - layers.*.feed_forward.w2.weight
|
||||
|
||||
// split_type = 1:
|
||||
// regex:
|
||||
// - output.*
|
||||
// - layers.*.attention.wq.weight
|
||||
// - layers.*.attention.wk.weight
|
||||
// - layers.*.attention.wv.weight
|
||||
// - layers.*.feed_forward.w1.weight
|
||||
// - layers.*.feed_forward.w3.weight
|
||||
if (name.find("tok_embeddings") != std::string::npos) {
|
||||
split_type = 0;
|
||||
} else if (name.find("layers") != std::string::npos) {
|
||||
if (name.find("attention.wo.weight") != std::string::npos) {
|
||||
split_type = 0;
|
||||
} else if (name.find("feed_forward.w2.weight") != std::string::npos) {
|
||||
split_type = 0;
|
||||
} else {
|
||||
split_type = 1;
|
||||
}
|
||||
} else if (name.find("output") != std::string::npos) {
|
||||
split_type = 1;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
|
||||
if (n_dims == 1) {
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (ggml_nelements(tensor)/n_parts != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_dims == 1) {
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (split_type == 0) {
|
||||
if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (0) {
|
||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
||||
fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
|
||||
}
|
||||
|
||||
size_t bpe = 0;
|
||||
|
||||
switch (ftype) {
|
||||
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
|
||||
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
|
||||
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
|
||||
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
if (n_dims == 1 || n_parts == 1) {
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (part_id == 0) {
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
} else {
|
||||
fin.seekg(ggml_nbytes(tensor), std::ios::cur);
|
||||
}
|
||||
|
||||
total_size += ggml_nbytes(tensor);
|
||||
} else {
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (split_type == 0) {
|
||||
const int np0 = ne[0];
|
||||
|
||||
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
||||
assert(row_size == tensor->nb[1]);
|
||||
|
||||
for (int i1 = 0; i1 < ne[1]; ++i1) {
|
||||
const size_t offset_row = i1*row_size;
|
||||
const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
||||
fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
|
||||
}
|
||||
} else {
|
||||
const int np1 = ne[1];
|
||||
|
||||
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
||||
|
||||
for (int i1 = 0; i1 < ne[1]; ++i1) {
|
||||
const size_t offset_row = (i1 + part_id*np1)*row_size;
|
||||
fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
|
||||
}
|
||||
}
|
||||
|
||||
total_size += ggml_nbytes(tensor)/n_parts;
|
||||
}
|
||||
|
||||
//fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
model.n_loaded++;
|
||||
|
||||
// progress
|
||||
if (progress_callback) {
|
||||
float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
|
||||
float current_progress = (float(i) + current_file_progress) / float(n_parts);
|
||||
progress_callback(current_progress, progress_callback_user_data);
|
||||
}
|
||||
if (model.n_loaded % 8 == 0) {
|
||||
fprintf(stderr, ".");
|
||||
fflush(stderr);
|
||||
}
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, " done\n");
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
|
||||
if (model.n_loaded == 0) {
|
||||
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
||||
} else if (model.n_loaded != (int) model.tensors.size()) {
|
||||
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
if (0) {
|
||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
||||
fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
||||
}
|
||||
|
||||
switch (ftype) {
|
||||
case 0: // f32
|
||||
case 1: // f16
|
||||
break;
|
||||
case 2: // q4_0
|
||||
case 3: // q4_1
|
||||
assert(ne[0] % 64 == 0);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
|
||||
return false;
|
||||
};
|
||||
|
||||
// load the tensor data into memory without copying or reading it
|
||||
size_t offset = fin.tellg();
|
||||
size_t tensor_data_size = ggml_nbytes(tensor);
|
||||
offset = (offset + 31) & -32;
|
||||
tensor->data = mm_addr + offset;
|
||||
fin.seekg(offset + tensor_data_size);
|
||||
total_size += tensor_data_size;
|
||||
model.n_loaded++;
|
||||
|
||||
// progress
|
||||
if (progress_callback) {
|
||||
double current_progress = size_t(fin.tellg()) / double(file_size);
|
||||
progress_callback(current_progress, progress_callback_user_data);
|
||||
}
|
||||
}
|
||||
|
||||
fin.close();
|
||||
|
||||
fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
|
||||
if (model.n_loaded == 0) {
|
||||
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
||||
} else if (model.n_loaded != (int) model.tensors.size()) {
|
||||
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
lctx.t_load_us = ggml_time_us() - t_start_us;
|
||||
// loading time will be recalculate after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
|
||||
|
||||
if (progress_callback) {
|
||||
progress_callback(1.0, progress_callback_user_data);
|
||||
|
@ -856,6 +776,7 @@ static bool llama_eval_internal(
|
|||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size(),
|
||||
/*.mem_buffer =*/ buf_compute.data(),
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
@ -1133,7 +1054,7 @@ struct llama_tokenizer {
|
|||
size_t offs = 0;
|
||||
while (offs < text.size()) {
|
||||
llama_sp_symbol sym;
|
||||
size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
|
||||
size_t char_len = Min(text.size() - offs, utf8_len(text[offs]));
|
||||
sym.text = text.c_str() + offs;
|
||||
sym.n = char_len;
|
||||
offs += char_len;
|
||||
|
@ -1298,7 +1219,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
|||
|
||||
float maxl = -std::numeric_limits<float>::infinity();
|
||||
for (const auto & kv : logits_id) {
|
||||
maxl = std::max(maxl, kv.first);
|
||||
maxl = Max(maxl, kv.first);
|
||||
}
|
||||
|
||||
// compute probs for the top k tokens
|
||||
|
@ -1392,8 +1313,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
return false;
|
||||
}
|
||||
if (magic != LLAMA_FILE_MAGIC) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
return report_bad_magic(fname_inp.c_str(), magic, LLAMA_FILE_MAGIC);
|
||||
}
|
||||
|
||||
fout.write((char *) &magic, sizeof(magic));
|
||||
|
@ -1459,8 +1379,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
fout.write((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
finp.read ((char *) word.data(), len);
|
||||
fout.write((char *) word.data(), len);
|
||||
finp.read ((char *) &word[0], len);
|
||||
fout.write((char *) &word[0], len);
|
||||
|
||||
float score;
|
||||
finp.read ((char *) &score, sizeof(score));
|
||||
|
@ -1510,6 +1430,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
std::string name(length, 0);
|
||||
finp.read (&name[0], length);
|
||||
|
||||
{
|
||||
// ensure tensor data is aligned
|
||||
uint64_t offset = finp.tellg();
|
||||
offset = (offset + 31) & -32;
|
||||
finp.seekg(offset);
|
||||
}
|
||||
|
||||
{
|
||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
||||
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
||||
|
@ -1565,6 +1492,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
}
|
||||
fout.write(&name[0], length);
|
||||
|
||||
{
|
||||
// ensure tensor data is aligned
|
||||
uint64_t offset = fout.tellp();
|
||||
offset = (offset + 31) & -32;
|
||||
fout.seekp(offset);
|
||||
}
|
||||
|
||||
if (quantize) {
|
||||
printf("quantizing .. ");
|
||||
work.resize(nelements); // for quantization
|
||||
|
@ -1662,7 +1596,10 @@ struct llama_context * llama_init_from_file(
|
|||
|
||||
if (params.use_mlock) {
|
||||
char *err;
|
||||
if (!ggml_mlock(ctx->model.ctx, &err)) {
|
||||
if (!ggml_mlock(ctx->model.ctx,
|
||||
ctx->model.mm_addr,
|
||||
ctx->model.mm_length,
|
||||
&err)) {
|
||||
fprintf(stderr, "%s\n", err);
|
||||
free(err);
|
||||
llama_free(ctx);
|
||||
|
@ -1712,6 +1649,10 @@ void llama_free(struct llama_context * ctx) {
|
|||
ggml_free(ctx->model.ctx);
|
||||
}
|
||||
|
||||
if (ctx->model.mm_addr) {
|
||||
munmap_file(ctx->model.mm_addr, ctx->model.mm_length);
|
||||
}
|
||||
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
|
@ -1737,7 +1678,11 @@ int llama_eval(
|
|||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// get a more accurate load time, upon first eval
|
||||
if (!ctx->has_evaluated_once) {
|
||||
ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
|
||||
ctx->has_evaluated_once = true;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -1830,9 +1775,9 @@ llama_token llama_sample_top_p_top_k(
|
|||
void llama_print_timings(struct llama_context * ctx) {
|
||||
const int64_t t_end_us = ggml_time_us();
|
||||
|
||||
const int32_t n_sample = std::max(1, ctx->n_sample);
|
||||
const int32_t n_eval = std::max(1, ctx->n_eval);
|
||||
const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
|
||||
const int32_t n_sample = Max(1, ctx->n_sample);
|
||||
const int32_t n_eval = Max(1, ctx->n_eval);
|
||||
const int32_t n_p_eval = Max(1, ctx->n_p_eval);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
|
||||
|
@ -1844,7 +1789,6 @@ void llama_print_timings(struct llama_context * ctx) {
|
|||
|
||||
void llama_reset_timings(struct llama_context * ctx) {
|
||||
ctx->t_start_us = ggml_time_us();
|
||||
|
||||
ctx->t_sample_us = ctx->n_sample = 0;
|
||||
ctx->t_eval_us = ctx->n_eval = 0;
|
||||
ctx->t_p_eval_us = ctx->n_p_eval = 0;
|
||||
|
|
2
llama.h
2
llama.h
|
@ -20,7 +20,7 @@
|
|||
#endif
|
||||
|
||||
#define LLAMA_FILE_VERSION 1
|
||||
#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
|
||||
#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
|
||||
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
BIN
llamacpp.dll
BIN
llamacpp.dll
Binary file not shown.
Binary file not shown.
BIN
main.exe
BIN
main.exe
Binary file not shown.
313
migrate-ggml-2023-03-30-pr613.py
Normal file
313
migrate-ggml-2023-03-30-pr613.py
Normal file
|
@ -0,0 +1,313 @@
|
|||
# Migrate ggml file(s) with ggmf magic to ggml file with ggjt magic
|
||||
#
|
||||
# We caused a breaking change to the file format on 2023-03-30 in:
|
||||
# https://github.com/ggerganov/llama.cpp/pull/613
|
||||
#
|
||||
# (1) If you still have the Meta LLaMA .pth files, then close this
|
||||
# file now; you can just run `convert-pth-to-ggml.py` again to
|
||||
# migrate to the new format. The tool is easier to use too. It
|
||||
# isn't necessary anymore to manage split output files because
|
||||
# the new format always combines things into a single file.
|
||||
#
|
||||
# (2) If you deleted the Meta LLaMA .pth files due to save on disk
|
||||
# space, then this tool is intended to help you. Please check
|
||||
# out the instructions below.
|
||||
#
|
||||
# USAGE
|
||||
#
|
||||
# python migrate-ggml-2023-03-30-pr613.py INPUT OUTPUT
|
||||
#
|
||||
# PREREQUISITES
|
||||
#
|
||||
# pip install numpy
|
||||
# cd llama.cpp
|
||||
# make -j4
|
||||
#
|
||||
# EXAMPLE (7B MODEL)
|
||||
#
|
||||
# # you can replace all the 'f16' with 'q4_0' if you're using quantized weights
|
||||
# python migrate-ggml-2023-03-30-pr613.py models/7B/ggml-model-f16.bin models/7B/ggml-model-f16-ggjt.bin
|
||||
#
|
||||
# # check that it works
|
||||
# ./main -m models/7B/ggml-model-f16-ggjt.bin -p 'Question: Do you love me?'
|
||||
#
|
||||
# # you can delete the old files
|
||||
# rm -f models/7B/ggml-model-f16.bin
|
||||
# mv models/7B/ggml-model-f16-ggjt.bin models/7B/ggml-model-f16.bin
|
||||
#
|
||||
# EXAMPLE (13B MODEL)
|
||||
#
|
||||
# # you can replace all the 'f16' with 'q4_0' if you're using quantized weights
|
||||
# python migrate-ggml-2023-03-30-pr613.py models/13B/ggml-model-f16.bin models/13B/ggml-model-f16-ggjt.bin
|
||||
#
|
||||
# # check that it works
|
||||
# ./main -m models/13B/ggml-model-f16-ggjt.bin -p 'Question: Do you love me?'
|
||||
#
|
||||
# # you can delete the old files
|
||||
# rm -f models/13B/ggml-model-f16.bin*
|
||||
# mv models/13B/ggml-model-f16-ggjt.bin models/13B/ggml-model-f16.bin
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
|
||||
QK = 32
|
||||
|
||||
GGML_TYPE_Q4_0 = 0
|
||||
GGML_TYPE_Q4_1 = 1
|
||||
GGML_TYPE_I8 = 2
|
||||
GGML_TYPE_I16 = 3
|
||||
GGML_TYPE_I32 = 4
|
||||
GGML_TYPE_F16 = 5
|
||||
GGML_TYPE_F32 = 6
|
||||
|
||||
WTYPE_NAMES = {
|
||||
0: "F32",
|
||||
1: "F16",
|
||||
2: "Q4_0",
|
||||
3: "Q4_1",
|
||||
}
|
||||
|
||||
WTYPES = {
|
||||
0: GGML_TYPE_F32,
|
||||
1: GGML_TYPE_F16,
|
||||
2: GGML_TYPE_Q4_0,
|
||||
3: GGML_TYPE_Q4_1,
|
||||
}
|
||||
|
||||
GGML_BLCK_SIZE = {
|
||||
GGML_TYPE_Q4_0: QK,
|
||||
GGML_TYPE_Q4_1: QK,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 1,
|
||||
GGML_TYPE_I32: 1,
|
||||
GGML_TYPE_F16: 1,
|
||||
GGML_TYPE_F32: 1,
|
||||
}
|
||||
|
||||
GGML_TYPE_SIZE = {
|
||||
GGML_TYPE_Q4_0: 4 + QK//2,
|
||||
GGML_TYPE_Q4_1: 4*2 + QK//2,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 2,
|
||||
GGML_TYPE_I32: 4,
|
||||
GGML_TYPE_F16: 2,
|
||||
GGML_TYPE_F32: 4,
|
||||
}
|
||||
|
||||
HPARAMS = [
|
||||
'magic', # int32
|
||||
'version', # int32
|
||||
'n_vocab', # int32
|
||||
'n_embd', # int32
|
||||
'n_mult', # int32
|
||||
'n_head', # int32
|
||||
'n_layer', # int32
|
||||
'n_rot', # int32
|
||||
'f16', # int32
|
||||
]
|
||||
|
||||
def read_hparams(fin):
|
||||
struct_fmt = "i" * len(HPARAMS)
|
||||
struct_size = struct.calcsize(struct_fmt)
|
||||
buf = fin.read(struct_size)
|
||||
ints = struct.unpack(struct_fmt, buf)
|
||||
hparams = dict(zip(HPARAMS, ints))
|
||||
return hparams
|
||||
|
||||
def write_hparams(fout, hparams):
|
||||
struct_fmt = "i" * len(HPARAMS)
|
||||
struct_size = struct.calcsize(struct_fmt)
|
||||
ints = [hparams[h] for h in HPARAMS]
|
||||
fout.write(struct.pack(struct_fmt, *ints))
|
||||
|
||||
def read_tokens(fin, hparams):
|
||||
tokens = []
|
||||
for i in range(hparams['n_vocab']):
|
||||
len_b = fin.read(4)
|
||||
(length,) = struct.unpack("i", len_b)
|
||||
word = fin.read(length)
|
||||
score_b = fin.read(4)
|
||||
(score,) = struct.unpack("f", score_b)
|
||||
tokens.append((word, score))
|
||||
return tokens
|
||||
|
||||
def write_tokens(fout, tokens):
|
||||
for word, score in tokens:
|
||||
fout.write(struct.pack("i", len(word)))
|
||||
fout.write(word)
|
||||
fout.write(struct.pack("f", score))
|
||||
|
||||
def ggml_nelements(shape):
|
||||
r = 1
|
||||
for i in shape:
|
||||
r *= i
|
||||
return r
|
||||
|
||||
def ggml_nbytes(shape, ftype):
|
||||
x = ggml_nelements(shape)
|
||||
t = WTYPES[ftype]
|
||||
x *= GGML_TYPE_SIZE[t]
|
||||
x //= GGML_BLCK_SIZE[t]
|
||||
return x
|
||||
|
||||
def copy_tensors(fin, fout, part_id, n_parts):
|
||||
while True:
|
||||
|
||||
b = fin.read(4)
|
||||
if not b: break
|
||||
(n_dims,) = struct.unpack("i", b)
|
||||
b = fin.read(4)
|
||||
(length,) = struct.unpack("i", b)
|
||||
b = fin.read(4)
|
||||
(ftype,) = struct.unpack("i", b)
|
||||
|
||||
assert n_dims in (1, 2)
|
||||
|
||||
partshape = list(range(n_dims))
|
||||
for i in range(n_dims):
|
||||
b = fin.read(4)
|
||||
partshape[i] = struct.unpack("i", b)[0]
|
||||
partshape = list(reversed(partshape))
|
||||
|
||||
name = fin.read(length)
|
||||
data = fin.read(ggml_nbytes(partshape, ftype))
|
||||
|
||||
blck_size = GGML_BLCK_SIZE[WTYPES[ftype]]
|
||||
type_size = GGML_TYPE_SIZE[WTYPES[ftype]]
|
||||
|
||||
print(f"Processing tensor {name} with shape: {partshape} and type: {WTYPE_NAMES[ftype]}")
|
||||
|
||||
# determine dimension along which multipart tensor is sharded
|
||||
#
|
||||
# split_dim 0 regex:
|
||||
# - output.*
|
||||
# - layers.*.attention.wq.weight
|
||||
# - layers.*.attention.wk.weight
|
||||
# - layers.*.attention.wv.weight
|
||||
# - layers.*.feed_forward.w1.weight
|
||||
# - layers.*.feed_forward.w3.weight
|
||||
#
|
||||
# split_dim 1 regex:
|
||||
# - tok_embeddings.*
|
||||
# - layers.*.attention.wo.weight
|
||||
# - layers.*.feed_forward.w2.weight
|
||||
#
|
||||
if n_dims > 1:
|
||||
split_dim = 1
|
||||
if b"tok_embeddings" in name:
|
||||
split_dim = 1
|
||||
elif b"layers" in name:
|
||||
if b"attention.wo.weight" in name:
|
||||
split_dim = 1
|
||||
elif b"feed_forward.w2.weight" in name:
|
||||
split_dim = 1
|
||||
else:
|
||||
split_dim = 0
|
||||
elif b"output" in name:
|
||||
split_dim = 0
|
||||
|
||||
# output tensor header
|
||||
fullshape = list(partshape)
|
||||
if n_dims > 1:
|
||||
fullshape[split_dim] *= n_parts
|
||||
fout.write(struct.pack("iii", n_dims, len(name), ftype))
|
||||
for dim in reversed(fullshape):
|
||||
fout.write(struct.pack("i", dim))
|
||||
fout.write(name)
|
||||
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fout.tell()
|
||||
while tensor_data_offset % QK != 0:
|
||||
fout.write(struct.pack("B", 0))
|
||||
tensor_data_offset += 1
|
||||
|
||||
# output unified mappable tensor data
|
||||
if n_dims == 1 or n_parts == 1:
|
||||
# copy tensor which we thankfully received in one piece
|
||||
if part_id == 0:
|
||||
fout.write(data)
|
||||
elif split_dim == 0:
|
||||
# reassemble multifile tensor containing some of the rows
|
||||
rows_per_chunk = partshape[0]
|
||||
current_row = part_id * rows_per_chunk
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset = current_row * bytes_per_row
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
fout.write(data)
|
||||
elif split_dim == 1:
|
||||
# reassemble multifile tensor containing some of the cols
|
||||
cols_per_chunk = partshape[1]
|
||||
current_col = part_id * cols_per_chunk
|
||||
bpr = partshape[1] // blck_size * type_size
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset_current_col = current_col // blck_size * type_size
|
||||
for row in range(partshape[0]):
|
||||
offset_row = row * bytes_per_row
|
||||
offset = offset_row + offset_current_col
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
fout.write(data[row * bpr:row * bpr + bpr])
|
||||
|
||||
# advance file position to next tensor
|
||||
fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype))
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Migrate from GGML to new GGJT file format')
|
||||
parser.add_argument('fin_path', help='your old ggml file (leave out the .1 .2 etc.)')
|
||||
parser.add_argument('fout_path', help='your new ggjt file name')
|
||||
return parser.parse_args()
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
assert args.fin_path
|
||||
assert args.fout_path
|
||||
assert args.fin_path != args.fout_path
|
||||
|
||||
with open(args.fin_path, "rb") as fin:
|
||||
hparams = read_hparams(fin)
|
||||
tokens = read_tokens(fin, hparams)
|
||||
|
||||
if hparams['magic'] == 0x67676a74: # ggjt
|
||||
print("%s: input ggml has already been converted to 'ggjt' magic\n" %
|
||||
(args.fin_path))
|
||||
sys.exit(1)
|
||||
|
||||
if hparams['magic'] != 0x67676d66: # ggmf
|
||||
print("%s: input ggml file doesn't have expected 'ggmf' magic: %#x\n" %
|
||||
(args.fin_path, hparams['magic']))
|
||||
sys.exit(1)
|
||||
|
||||
hparams['magic'] = 0x67676a74 # ggjt
|
||||
|
||||
# count number of multipart files by convention
|
||||
n_parts = 1
|
||||
while True:
|
||||
if os.path.exists("%s.%d" % (args.fin_path, n_parts)):
|
||||
n_parts += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# we output a single file for ggml
|
||||
with open(args.fout_path, "wb") as fout:
|
||||
write_hparams(fout, hparams)
|
||||
write_tokens(fout, tokens)
|
||||
offset_of_tensors = fout.tell()
|
||||
# the tensors we load could be split across multiple files
|
||||
for part_id in range(n_parts):
|
||||
fout.seek(offset_of_tensors)
|
||||
print(f"Processing part {part_id+1} of {n_parts}\n")
|
||||
fin_path = args.fin_path
|
||||
if part_id > 0:
|
||||
fin_path += ".%d" % (part_id)
|
||||
with open(fin_path, "rb") as fin:
|
||||
read_tokens(fin, read_hparams(fin))
|
||||
copy_tensors(fin, fout, part_id, n_parts)
|
||||
|
||||
print(f"Done. Output file: {args.fout_path}\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary file not shown.
BIN
quantize.exe
BIN
quantize.exe
Binary file not shown.
131
quantize.py
131
quantize.py
|
@ -1,131 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""Script to execute the "quantize" script on a given set of models."""
|
||||
|
||||
import subprocess
|
||||
import argparse
|
||||
import glob
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def main():
|
||||
"""Update the quantize binary name depending on the platform and parse
|
||||
the command line arguments and execute the script.
|
||||
"""
|
||||
|
||||
if "linux" in sys.platform or "darwin" in sys.platform:
|
||||
quantize_script_binary = "quantize"
|
||||
|
||||
elif "win32" in sys.platform or "cygwin" in sys.platform:
|
||||
quantize_script_binary = "quantize.exe"
|
||||
|
||||
else:
|
||||
print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n")
|
||||
quantize_script_binary = "quantize"
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='python3 quantize.py',
|
||||
description='This script quantizes the given models by applying the '
|
||||
f'"{quantize_script_binary}" script on them.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
|
||||
help='The models to quantize.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-r', '--remove-16', action='store_true', dest='remove_f16',
|
||||
help='Remove the f16 model after quantizing it.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-m', '--models-path', dest='models_path',
|
||||
default=os.path.join(os.getcwd(), "models"),
|
||||
help='Specify the directory where the models are located.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-q', '--quantize-script-path', dest='quantize_script_path',
|
||||
default=os.path.join(os.getcwd(), quantize_script_binary),
|
||||
help='Specify the path to the "quantize" script.'
|
||||
)
|
||||
|
||||
# TODO: Revise this code
|
||||
# parser.add_argument(
|
||||
# '-t', '--threads', dest='threads', type='int',
|
||||
# default=os.cpu_count(),
|
||||
# help='Specify the number of threads to use to quantize many models at '
|
||||
# 'once. Defaults to os.cpu_count().'
|
||||
# )
|
||||
|
||||
args = parser.parse_args()
|
||||
args.models_path = os.path.abspath(args.models_path)
|
||||
|
||||
if not os.path.isfile(args.quantize_script_path):
|
||||
print(
|
||||
f'The "{quantize_script_binary}" script was not found in the '
|
||||
"current location.\nIf you want to use it from another location, "
|
||||
"set the --quantize-script-path argument from the command line."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
for model in args.models:
|
||||
# The model is separated in various parts
|
||||
# (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
|
||||
f16_model_path_base = os.path.join(
|
||||
args.models_path, model, "ggml-model-f16.bin"
|
||||
)
|
||||
|
||||
if not os.path.isfile(f16_model_path_base):
|
||||
print(f'The file %s was not found' % f16_model_path_base)
|
||||
sys.exit(1)
|
||||
|
||||
f16_model_parts_paths = map(
|
||||
lambda filename: os.path.join(f16_model_path_base, filename),
|
||||
glob.glob(f"{f16_model_path_base}*")
|
||||
)
|
||||
|
||||
for f16_model_part_path in f16_model_parts_paths:
|
||||
if not os.path.isfile(f16_model_part_path):
|
||||
print(
|
||||
f"The f16 model {os.path.basename(f16_model_part_path)} "
|
||||
f"was not found in {args.models_path}{os.path.sep}{model}"
|
||||
". If you want to use it from another location, set the "
|
||||
"--models-path argument from the command line."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
__run_quantize_script(
|
||||
args.quantize_script_path, f16_model_part_path
|
||||
)
|
||||
|
||||
if args.remove_f16:
|
||||
os.remove(f16_model_part_path)
|
||||
|
||||
|
||||
# This was extracted to a top-level function for parallelization, if
|
||||
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
|
||||
|
||||
def __run_quantize_script(script_path, f16_model_part_path):
|
||||
"""Run the quantize script specifying the path to it and the path to the
|
||||
f16 model to quantize.
|
||||
"""
|
||||
|
||||
new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
|
||||
subprocess.run(
|
||||
[script_path, f16_model_part_path, new_quantized_model_path, "2"],
|
||||
check=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
|
||||
except subprocess.CalledProcessError:
|
||||
print("\nAn error ocurred while trying to quantize the models.")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(0)
|
||||
|
||||
else:
|
||||
print("\nSuccesfully quantized all models.")
|
Loading…
Add table
Add a link
Reference in a new issue