Revert "Make loading weights 10-100x faster"

This reverts commit 78ca9838ee.
This commit is contained in:
anzz1 2023-04-02 15:39:28 +03:00
parent 255f019f88
commit 17b98ca9ff
7 changed files with 370 additions and 331 deletions

1
.gitignore vendored
View file

@ -22,7 +22,6 @@ models/*
/result /result
/perplexity /perplexity
/embedding /embedding
/Pipfile
arm_neon.h arm_neon.h
compile_commands.json compile_commands.json

View file

@ -84,11 +84,6 @@ def read_variables(fin):
shape = shape[::-1] shape = shape[::-1]
name = fin.read(name_length).decode() name = fin.read(name_length).decode()
# ensure tensor data is aligned
tensor_data_offset = fin.tell()
tensor_data_offset = (tensor_data_offset + 31) & -32
fin.seek(tensor_data_offset)
if ftype_cur == 2: if ftype_cur == 2:
# 4-bit quantized weights # 4-bit quantized weights
dtype = np.uint8 dtype = np.uint8

View file

@ -72,11 +72,6 @@ def write_header(shape, dst_name, ftype_cur):
fout.write(struct.pack("i" * len(shape), *shape[::-1])) fout.write(struct.pack("i" * len(shape), *shape[::-1]))
fout.write(sname) fout.write(sname)
# ensure tensor data is aligned
tensor_data_offset = fout.tell()
tensor_data_offset = (tensor_data_offset + 31) & -32
fout.seek(tensor_data_offset)
def convert_non_q4(src_name, dst_name): def convert_non_q4(src_name, dst_name):
v = model[src_name] v = model[src_name]
shape = v.shape shape = v.shape

View file

@ -24,57 +24,8 @@ import torch
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
QK = 32
GGML_TYPE_Q4_0 = 0
GGML_TYPE_Q4_1 = 1
GGML_TYPE_I8 = 2
GGML_TYPE_I16 = 3
GGML_TYPE_I32 = 4
GGML_TYPE_F16 = 5
GGML_TYPE_F32 = 6
WTYPES = {
0: GGML_TYPE_F32,
1: GGML_TYPE_F16,
2: GGML_TYPE_Q4_0,
3: GGML_TYPE_Q4_1,
}
GGML_BLCK_SIZE = {
GGML_TYPE_Q4_0: QK,
GGML_TYPE_Q4_1: QK,
GGML_TYPE_I8: 1,
GGML_TYPE_I16: 1,
GGML_TYPE_I32: 1,
GGML_TYPE_F16: 1,
GGML_TYPE_F32: 1,
}
GGML_TYPE_SIZE = {
GGML_TYPE_Q4_0: 4 + QK/2,
GGML_TYPE_Q4_1: 4*2 + QK/2,
GGML_TYPE_I8: 1,
GGML_TYPE_I16: 2,
GGML_TYPE_I32: 4,
GGML_TYPE_F16: 2,
GGML_TYPE_F32: 4,
}
def ggml_nelements(shape):
r = 1
for i in shape:
r *= i
return r
def ggml_nbytes(shape, ftype):
x = ggml_nelements(shape)
t = WTYPES[ftype]
x *= GGML_TYPE_SIZE[t]
x //= GGML_BLCK_SIZE[t]
return x
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
parser.add_argument('dir_model', help='directory containing the model checkpoint') parser.add_argument('dir_model', help='directory containing the model checkpoint')
parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1) parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
@ -82,6 +33,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def get_n_parts(dim): def get_n_parts(dim):
mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8} mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
n_parts = mappings.get(dim) n_parts = mappings.get(dim)
if n_parts is None: if n_parts is None:
@ -92,24 +44,30 @@ def get_n_parts(dim):
return n_parts return n_parts
def load_hparams_and_tokenizer(dir_model): def load_hparams_and_tokenizer(dir_model):
# `dir_model` is something like `models/7B` or `models/7B/`. # `dir_model` is something like `models/7B` or `models/7B/`.
# "tokenizer.model" is expected under model's parent dir. # "tokenizer.model" is expected under model's parent dir.
# When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found. # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
# Let's use the model's parent dir directly. # Let's use the model's parent dir directly.
model_parent_dir = os.path.dirname(os.path.normpath(dir_model)) model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
fname_hparams = f"{dir_model}/params.json" fname_hparams = f"{dir_model}/params.json"
fname_tokenizer = f"{model_parent_dir}/tokenizer.model" fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
with open(fname_hparams, "r") as f: with open(fname_hparams, "r") as f:
hparams = json.load(f) hparams = json.load(f)
print(hparams) print(hparams)
tokenizer = SentencePieceProcessor(fname_tokenizer) tokenizer = SentencePieceProcessor(fname_tokenizer)
hparams.update({"vocab_size": tokenizer.vocab_size()}) hparams.update({"vocab_size": tokenizer.vocab_size()})
return hparams, tokenizer return hparams, tokenizer
def write_header(fout, hparams, ftype): def write_header(fout, hparams, ftype):
keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"] keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
values = [ values = [
0x67676a74, # magic: ggjt in hex 0x67676d66, # magic: ggmf in hex
1, # file version 1, # file version
*[hparams[key] for key in keys], *[hparams[key] for key in keys],
hparams["dim"] // hparams["n_heads"], # rot (obsolete) hparams["dim"] // hparams["n_heads"], # rot (obsolete)
@ -118,6 +76,7 @@ def write_header(fout, hparams, ftype):
fout.write(struct.pack("i" * len(values), *values)) fout.write(struct.pack("i" * len(values), *values))
def write_tokens(fout, tokenizer): def write_tokens(fout, tokenizer):
for i in range(tokenizer.vocab_size()): for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i): if tokenizer.is_unknown(i):
text = " \u2047 ".encode() text = " \u2047 ".encode()
@ -136,141 +95,85 @@ def write_tokens(fout, tokenizer):
fout.write(text) fout.write(text)
fout.write(struct.pack("f", tokenizer.get_score(i))) fout.write(struct.pack("f", tokenizer.get_score(i)))
def process_and_write_variables(fout, model, ftype, part_id, n_parts): def process_and_write_variables(fout, model, ftype):
for name, datao in model.items(): for name, datao in model.items():
if name.endswith("freqs"): if name.endswith("freqs"):
continue continue
# remove dimensions with a single element shape = datao.shape
print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
data = datao.numpy().squeeze() data = datao.numpy().squeeze()
partshape = data.shape n_dims = len(shape)
n_dims = len(data.shape)
assert n_dims in (1, 2)
print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}") # default type is fp16
# coerce single-dimensional tensors from float16 to float32
ftype_cur = 1 ftype_cur = 1
if ftype == 0 or n_dims == 1: if ftype == 0 or n_dims == 1:
print(" Converting to float32") print(" Converting to float32")
data = data.astype(np.float32) data = data.astype(np.float32)
ftype_cur = 0 ftype_cur = 0
blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
# determine dimension along which multipart tensor is sharded # header
# sname = name.encode('utf-8')
# split_dim 0 regex: fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
# - output.* for dim in reversed(data.shape):
# - layers.*.attention.wq.weight
# - layers.*.attention.wk.weight
# - layers.*.attention.wv.weight
# - layers.*.feed_forward.w1.weight
# - layers.*.feed_forward.w3.weight
#
# split_dim 1 regex:
# - tok_embeddings.*
# - layers.*.attention.wo.weight
# - layers.*.feed_forward.w2.weight
#
if n_dims > 1:
split_dim = 1
if "tok_embeddings" in name:
split_dim = 1
elif "layers" in name:
if "attention.wo.weight" in name:
split_dim = 1
elif "feed_forward.w2.weight" in name:
split_dim = 1
else:
split_dim = 0
elif "output" in name:
split_dim = 0
# output tensor header
fullshape = list(partshape)
if n_dims > 1:
fullshape[split_dim] *= n_parts
sname = name.encode()
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
for dim in reversed(fullshape):
fout.write(struct.pack("i", dim)) fout.write(struct.pack("i", dim))
fout.write(sname) fout.write(sname)
# ensure tensor data is aligned # data output to file
tensor_data_offset = fout.tell() data.tofile(fout)
while tensor_data_offset % QK != 0:
fout.write(struct.pack("B", 0))
tensor_data_offset += 1
# output unified mappable tensor data
if n_dims == 1 or n_parts == 1:
# copy tensor which we thankfully received in one piece
if part_id == 0:
data.tofile(fout)
elif split_dim == 0:
# reassemble multifile tensor containing some of the rows
rows_per_chunk = partshape[0]
current_row = part_id * rows_per_chunk
bytes_per_row = fullshape[1] // blck_size * type_size
offset = current_row * bytes_per_row
fout.seek(tensor_data_offset + offset)
data.tofile(fout)
elif split_dim == 1:
# reassemble multifile tensor containing some of the cols
cols_per_chunk = partshape[1]
current_col = part_id * cols_per_chunk
bytes_per_row = fullshape[1] // blck_size * type_size
offset_current_col = current_col // blck_size * type_size
for row in range(partshape[0]):
offset_row = row * bytes_per_row
offset = offset_row + offset_current_col
fout.seek(tensor_data_offset + offset)
data[row].tofile(fout)
# advance file position to next tensor
fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
def main(): def main():
args = parse_args() args = parse_args()
dir_model = args.dir_model dir_model = args.dir_model
ftype = args.ftype ftype = args.ftype
ftype_str = ["f32", "f16"] ftype_str = ["f32", "f16"]
hparams, tokenizer = load_hparams_and_tokenizer(dir_model) hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
print(args) print(args)
# if only writing vocab to file # if only writing vocab to file
if args.vocab_only: if args.vocab_only:
fname_model = f"{dir_model}/consolidated.00.pth" fname_model = f"{dir_model}/consolidated.00.pth"
fname_out = f"{dir_model}/ggml-vocab.bin" fname_out = f"{dir_model}/ggml-vocab.bin"
print(f"Extracting only the vocab from '{fname_model}'\n") print(f"Extracting only the vocab from '{fname_model}'\n")
model = torch.load(fname_model, map_location="cpu")
with open(fname_out, "wb") as fout: with open(fname_out, "wb") as fout:
write_header(fout, hparams, ftype) write_header(fout, hparams, ftype)
write_tokens(fout, tokenizer) write_tokens(fout, tokenizer)
del model
print(f"Done. Output file: {fname_out}\n") print(f"Done. Output file: {fname_out}\n")
return return
n_parts = get_n_parts(hparams["dim"]) n_parts = get_n_parts(hparams["dim"])
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
# we output a single file for ggml for p in range(n_parts):
with open(fname_out, "wb") as fout:
write_header(fout, hparams, ftype)
write_tokens(fout, tokenizer)
offset_of_tensors = fout.tell()
# the tensors we load could be split across multiple files
for part_id in range(n_parts):
fout.seek(offset_of_tensors)
print(f"Processing part {part_id+1} of {n_parts}\n")
fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
model = torch.load(fname_model, map_location="cpu")
process_and_write_variables(fout, model, ftype, part_id, n_parts)
del model
print(f"Done. Output file: {fname_out}\n") print(f"Processing part {p+1} of {n_parts}\n")
fname_model = f"{dir_model}/consolidated.0{p}.pth"
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
model = torch.load(fname_model, map_location="cpu")
with open(fname_out, "wb") as fout:
write_header(fout, hparams, ftype)
write_tokens(fout, tokenizer)
process_and_write_variables(fout, model, ftype)
del model
print(f"Done. Output file: {fname_out}, (part {p})\n")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

491
llama.cpp
View file

@ -12,19 +12,17 @@
#include <cassert> #include <cassert>
#include <cstring> #include <cstring>
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES) // mmap
#define WIN32_LEAN_AND_MEAN #if defined (__unix__) || defined (__APPLE__)
#include <Windows.h> # include <sys/mman.h>
#else # include <fcntl.h>
#include <sys/types.h> # include <unistd.h>
#include <sys/mman.h> #elif defined(_WIN32)
#include <unistd.h> # define WIN32_LEAN_AND_MEAN
#include <fcntl.h> # include <Windows.h>
//#include <Memoryapi.h>
#endif #endif
#define Min(X, Y) ((Y) > (X) ? (X) : (Y))
#define Max(X, Y) ((Y) < (X) ? (X) : (Y))
#define LLAMA_USE_SCRATCH #define LLAMA_USE_SCRATCH
#define LLAMA_MAX_SCRATCH_BUFFERS 16 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@ -157,7 +155,7 @@ struct llama_model {
// model memory mapped file // model memory mapped file
void * mm_addr = NULL; void * mm_addr = NULL;
uint64_t mm_length = 0; size_t mm_length = 0;
// tensors // tensors
int n_loaded; int n_loaded;
@ -182,7 +180,6 @@ struct llama_context {
int64_t t_load_us = 0; int64_t t_load_us = 0;
int64_t t_start_us = 0; int64_t t_start_us = 0;
bool has_evaluated_once = false;
int64_t t_sample_us = 0; int64_t t_sample_us = 0;
int64_t t_eval_us = 0; int64_t t_eval_us = 0;
@ -224,7 +221,7 @@ struct llama_context {
} }
if (buf_last >= 0) { if (buf_last >= 0) {
buf_max_size[buf_last] = Max(buf_max_size[buf_last], last_size); buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
} }
buf_last = i; buf_last = i;
@ -307,57 +304,59 @@ struct llama_context_params llama_context_default_params() {
// model loading // model loading
// //
static void *mmap_file(const char *fname, uint64_t *mm_length) { static void mmap_file(const char* fname, void * &mm_addr, size_t &mm_length) {
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES) #if defined(MAP_FAILED)
HANDLE hFile = CreateFileA(fname, // POSIX
GENERIC_READ, int fd = open(fname, O_RDONLY);
FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, mm_length = lseek(fd, 0, SEEK_END);
NULL, mm_addr = mmap(NULL, mm_length, PROT_READ, MAP_SHARED, fd, 0);
OPEN_EXISTING, close(fd);
FILE_ATTRIBUTE_NORMAL | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED, if (mm_addr == MAP_FAILED) {
NULL); perror("mmap failed");
if (hFile == INVALID_HANDLE_VALUE) return 0; mm_addr = NULL;
mm_length = 0;
}
#elif defined(_WIN32)
mm_addr = NULL;
HANDLE hFile = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
if (hFile == INVALID_HANDLE_VALUE) {
return;
}
// not really necessary
LARGE_INTEGER fileSize; LARGE_INTEGER fileSize;
fileSize.QuadPart = -1;
GetFileSizeEx(hFile, &fileSize); GetFileSizeEx(hFile, &fileSize);
int64_t length = fileSize.QuadPart; mm_length = fileSize;
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
CloseHandle(hFile); CloseHandle(hFile);
if (!hMapping) return 0;
void *addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); if (hMapping == NULL) {
return;
}
mm_addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
CloseHandle(hMapping); CloseHandle(hMapping);
if (!addr) return 0;
#else #else
int fd = open(fname, O_RDONLY); mm_addr = NULL;
if (fd == -1) return 0; mm_length = 0;
int64_t length = lseek(fd, 0, SEEK_END); (void)(fname); // suppress warnings
void *addr = mmap(NULL, length, PROT_READ, MAP_SHARED, fd, 0);
close(fd);
if (addr == MAP_FAILED) return 0;
#endif #endif
*mm_length = length;
return addr;
} }
static void munmap_file(void * addr, size_t length) { static void munmap_file(void * addr, size_t length) {
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES) #if defined(MAP_FAILED)
// POSIX
munmap(addr, length);
#elif defined(_WIN32)
UnmapViewOfFile(addr); UnmapViewOfFile(addr);
#else #else
munmap(addr, length); (void)(addr); // suppress warnings
(void)(length);
#endif #endif
} }
static bool report_bad_magic(const char *path) {
fprintf(stderr,
"%s: invalid model file (bad magic)\n"
"you most likely need to regenerate your ggml files\n"
"the benefit is you'll get 10-100x faster load times\n"
"see https://github.com/ggerganov/llama.cpp/issues/91\n"
"use convert-pth-to-ggml.py on your llama model files\n",
path);
return false;
}
static bool llama_model_load( static bool llama_model_load(
const std::string & fname, const std::string & fname,
llama_context & lctx, llama_context & lctx,
@ -369,24 +368,23 @@ static bool llama_model_load(
void *progress_callback_user_data) { void *progress_callback_user_data) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
lctx.t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
lctx.t_start_us = t_start_us;
// TODO: this could probably be smaller when using mmap
std::vector<char> f_buf(1024*1024);
auto & model = lctx.model; auto & model = lctx.model;
auto & vocab = lctx.vocab; auto & vocab = lctx.vocab;
auto fin = std::ifstream(fname, std::ios::binary); auto fin = std::ifstream(fname, std::ios::binary);
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
if (!fin) { if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false; return false;
} }
std::vector<char> f_buf(1024*1024);
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
fin.seekg(0, fin.end);
const size_t file_size = fin.tellg();
fin.seekg(0);
// verify magic // verify magic
{ {
uint32_t magic; uint32_t magic;
@ -397,7 +395,8 @@ static bool llama_model_load(
return false; return false;
} }
if (magic != LLAMA_FILE_MAGIC) { if (magic != LLAMA_FILE_MAGIC) {
return report_bad_magic(fname.c_str()); fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
} }
uint32_t format_version; uint32_t format_version;
@ -520,24 +519,54 @@ static bool llama_model_load(
} }
} }
// map model into memory bool use_mmap = (n_parts == 1);
char *mm_addr = NULL;
model.mm_addr = mmap_file(fname.c_str(), &model.mm_length); // try to memory map the model file
if (model.mm_addr == NULL) { void * mm_addr = NULL;
fprintf(stderr, "%s: failed to mmap '%s'\n", __func__, fname.c_str()); if (use_mmap) {
return false; mmap_file(fname.c_str(), model.mm_addr, model.mm_length);
if (model.mm_addr == NULL) {
use_mmap = false;
}
else {
mm_addr = model.mm_addr;
}
} }
mm_addr = (char *)model.mm_addr;
fprintf(stderr, "%s: ggml map size = %6.2f MB\n", __func__, model.mm_length/(1024.0*1024.0));
auto & ctx = model.ctx; auto & ctx = model.ctx;
size_t ctx_size = 0; size_t ctx_size = 0;
{ {
const auto &hparams = model.hparams; const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer; const int n_layer = hparams.n_layer;
const int n_vocab = hparams.n_vocab;
if (!use_mmap) {
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
}
ctx_size += (5 + 10*n_layer)*256; // object overhead ctx_size += (5 + 10*n_layer)*256; // object overhead
fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
} }
// print memory requirements // print memory requirements
@ -547,7 +576,6 @@ static bool llama_model_load(
// this is the total memory required to run the inference // this is the total memory required to run the inference
const size_t mem_required = const size_t mem_required =
ctx_size + ctx_size +
model.mm_length +
MEM_REQ_SCRATCH0.at(model.type) + MEM_REQ_SCRATCH0.at(model.type) +
MEM_REQ_SCRATCH1.at(model.type) + MEM_REQ_SCRATCH1.at(model.type) +
MEM_REQ_EVAL.at (model.type); MEM_REQ_EVAL.at (model.type);
@ -567,7 +595,7 @@ static bool llama_model_load(
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ lctx.model.buf.size(), /*.mem_size =*/ lctx.model.buf.size(),
/*.mem_buffer =*/ lctx.model.buf.data(), /*.mem_buffer =*/ lctx.model.buf.data(),
/*.no_alloc =*/ true, /*.no_alloc =*/ use_mmap,
}; };
model.ctx = ggml_init(params); model.ctx = ggml_init(params);
@ -630,106 +658,241 @@ static bool llama_model_load(
} }
} }
const size_t file_offset = fin.tellg();
fin.close();
std::vector<uint8_t> tmp; std::vector<uint8_t> tmp;
if (progress_callback) { if (progress_callback) {
progress_callback(0.0, progress_callback_user_data); progress_callback(0.0, progress_callback_user_data);
} }
fprintf(stderr, "%s: loading tensors from '%s'\n", __func__, fname.c_str()); for (int i = 0; i < n_parts; ++i) {
const int part_id = i;
//const int part_id = n_parts - i - 1;
// load weights std::string fname_part = fname;
{ if (i > 0) {
size_t total_size = 0; fname_part += "." + std::to_string(i);
model.n_loaded = 0; }
while (true) { fprintf(stderr, "%s: loading model part %d/%d from '%s'%s\n", __func__, i+1, n_parts, fname_part.c_str(), use_mmap ? " (memory mapped)" : "");
int32_t n_dims;
int32_t length;
int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); fin = std::ifstream(fname_part, std::ios::binary);
fin.read(reinterpret_cast<char *>(&length), sizeof(length)); fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) { fin.seekg(0, fin.end);
break; const size_t file_size = fin.tellg();
}
int32_t nelements = 1; fin.seekg(file_offset);
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0); // load weights
fin.read(&name[0], length); {
size_t total_size = 0;
if (model.tensors.find(name.data()) == model.tensors.end()) { model.n_loaded = 0;
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()]; fprintf(stderr, "%s: ", __func__);
if (ggml_nelements(tensor) != nelements) { while (true) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); int32_t n_dims;
return false; int32_t length;
} int32_t ftype;
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%" PRId64 ", %" PRId64 "], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false;
}
if (0) {
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]);
}
switch (ftype) { fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
case 0: // f32 fin.read(reinterpret_cast<char *>(&length), sizeof(length));
case 1: // f16 fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) {
break; break;
case 2: // q4_0 }
case 3: // q4_1
assert(ne[0] % 64 == 0); int32_t nelements = 1;
break; int32_t ne[2] = { 1, 1 };
default: for (int i = 0; i < n_dims; ++i) {
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false; return false;
}; }
// load the tensor data into memory without copying or reading it // split_type = 0: split by columns
size_t offset = fin.tellg(); // split_type = 1: split by rows
size_t tensor_data_size = ggml_nbytes(tensor); int split_type = 0;
offset = (offset + 31) & -32;
tensor->data = mm_addr + offset;
fin.seekg(offset + tensor_data_size);
total_size += tensor_data_size;
model.n_loaded++;
// progress // split_type = 0:
if (progress_callback) { // regex:
double current_progress = size_t(fin.tellg()) / double(file_size); // - tok_embeddings.*
progress_callback(current_progress, progress_callback_user_data); // - layers.*.attention.wo.weight
// - layers.*.feed_forward.w2.weight
// split_type = 1:
// regex:
// - output.*
// - layers.*.attention.wq.weight
// - layers.*.attention.wk.weight
// - layers.*.attention.wv.weight
// - layers.*.feed_forward.w1.weight
// - layers.*.feed_forward.w3.weight
if (name.find("tok_embeddings") != std::string::npos) {
split_type = 0;
} else if (name.find("layers") != std::string::npos) {
if (name.find("attention.wo.weight") != std::string::npos) {
split_type = 0;
} else if (name.find("feed_forward.w2.weight") != std::string::npos) {
split_type = 0;
} else {
split_type = 1;
}
} else if (name.find("output") != std::string::npos) {
split_type = 1;
}
auto tensor = model.tensors[name.data()];
if (n_dims == 1) {
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
} else {
if (ggml_nelements(tensor)/n_parts != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
}
if (n_dims == 1) {
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false;
}
} else {
if (split_type == 0) {
if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
return false;
}
} else {
if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
return false;
}
}
}
if (0) {
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
}
size_t bpe = 0;
switch (ftype) {
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
default:
{
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
return false;
}
};
if (n_dims == 1 || n_parts == 1) {
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
if (part_id == 0) {
if (mm_addr) {
off_t offset = fin.tellg();
tensor->data = (char *) mm_addr + offset;
fin.seekg(ggml_nbytes(tensor), std::ios::cur);
}
else {
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
}
} else {
fin.seekg(ggml_nbytes(tensor), std::ios::cur);
}
total_size += ggml_nbytes(tensor);
} else {
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
return false;
}
if (split_type == 0) {
const int np0 = ne[0];
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
assert(row_size == tensor->nb[1]);
for (int i1 = 0; i1 < ne[1]; ++i1) {
const size_t offset_row = i1*row_size;
const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
}
} else {
const int np1 = ne[1];
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
for (int i1 = 0; i1 < ne[1]; ++i1) {
const size_t offset_row = (i1 + part_id*np1)*row_size;
fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
}
}
total_size += ggml_nbytes(tensor)/n_parts;
}
//fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
model.n_loaded++;
// progress
if (progress_callback) {
float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
float current_progress = (float(i) + current_file_progress) / float(n_parts);
progress_callback(current_progress, progress_callback_user_data);
}
if (model.n_loaded % 8 == 0) {
fprintf(stderr, ".");
fflush(stderr);
}
}
fprintf(stderr, " done\n");
fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
if (model.n_loaded == 0) {
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
} else if (model.n_loaded != (int) model.tensors.size()) {
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
return false;
} }
} }
fin.close(); fin.close();
fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
if (model.n_loaded == 0) {
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
} else if (model.n_loaded != (int) model.tensors.size()) {
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
return false;
}
} }
// loading time will be recalculate after the first eval, so lctx.t_load_us = ggml_time_us() - t_start_us;
// we take page faults deferred by mmap() into consideration
lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
if (progress_callback) { if (progress_callback) {
progress_callback(1.0, progress_callback_user_data); progress_callback(1.0, progress_callback_user_data);
@ -1053,7 +1216,7 @@ struct llama_tokenizer {
size_t offs = 0; size_t offs = 0;
while (offs < text.size()) { while (offs < text.size()) {
llama_sp_symbol sym; llama_sp_symbol sym;
size_t char_len = Min(text.size() - offs, utf8_len(text[offs])); size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
sym.text = text.c_str() + offs; sym.text = text.c_str() + offs;
sym.n = char_len; sym.n = char_len;
offs += char_len; offs += char_len;
@ -1218,7 +1381,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
float maxl = -std::numeric_limits<float>::infinity(); float maxl = -std::numeric_limits<float>::infinity();
for (const auto & kv : logits_id) { for (const auto & kv : logits_id) {
maxl = Max(maxl, kv.first); maxl = std::max(maxl, kv.first);
} }
// compute probs for the top k tokens // compute probs for the top k tokens
@ -1312,7 +1475,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
return false; return false;
} }
if (magic != LLAMA_FILE_MAGIC) { if (magic != LLAMA_FILE_MAGIC) {
return report_bad_magic(fname_inp.c_str()); fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
return false;
} }
fout.write((char *) &magic, sizeof(magic)); fout.write((char *) &magic, sizeof(magic));
@ -1378,8 +1542,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
fout.write((char *) &len, sizeof(len)); fout.write((char *) &len, sizeof(len));
word.resize(len); word.resize(len);
finp.read ((char *) &word[0], len); finp.read ((char *) word.data(), len);
fout.write((char *) &word[0], len); fout.write((char *) word.data(), len);
float score; float score;
finp.read ((char *) &score, sizeof(score)); finp.read ((char *) &score, sizeof(score));
@ -1429,13 +1593,6 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
std::string name(length, 0); std::string name(length, 0);
finp.read (&name[0], length); finp.read (&name[0], length);
{
// ensure tensor data is aligned
uint64_t offset = finp.tellg();
offset = (offset + 31) & -32;
finp.seekg(offset);
}
{ {
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]); printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
@ -1491,13 +1648,6 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
} }
fout.write(&name[0], length); fout.write(&name[0], length);
{
// ensure tensor data is aligned
uint64_t offset = fout.tellp();
offset = (offset + 31) & -32;
fout.seekp(offset);
}
if (quantize) { if (quantize) {
printf("quantizing .. "); printf("quantizing .. ");
work.resize(nelements); // for quantization work.resize(nelements); // for quantization
@ -1701,11 +1851,7 @@ int llama_eval(
fprintf(stderr, "%s: failed to eval\n", __func__); fprintf(stderr, "%s: failed to eval\n", __func__);
return 1; return 1;
} }
// get a more accurate load time, upon first eval
if (!ctx->has_evaluated_once) {
ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
ctx->has_evaluated_once = true;
}
return 0; return 0;
} }
@ -1798,9 +1944,9 @@ llama_token llama_sample_top_p_top_k(
void llama_print_timings(struct llama_context * ctx) { void llama_print_timings(struct llama_context * ctx) {
const int64_t t_end_us = ggml_time_us(); const int64_t t_end_us = ggml_time_us();
const int32_t n_sample = Max(1, ctx->n_sample); const int32_t n_sample = std::max(1, ctx->n_sample);
const int32_t n_eval = Max(1, ctx->n_eval); const int32_t n_eval = std::max(1, ctx->n_eval);
const int32_t n_p_eval = Max(1, ctx->n_p_eval); const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
@ -1812,6 +1958,7 @@ void llama_print_timings(struct llama_context * ctx) {
void llama_reset_timings(struct llama_context * ctx) { void llama_reset_timings(struct llama_context * ctx) {
ctx->t_start_us = ggml_time_us(); ctx->t_start_us = ggml_time_us();
ctx->t_sample_us = ctx->n_sample = 0; ctx->t_sample_us = ctx->n_sample = 0;
ctx->t_eval_us = ctx->n_eval = 0; ctx->t_eval_us = ctx->n_eval = 0;
ctx->t_p_eval_us = ctx->n_p_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0;

View file

@ -20,7 +20,7 @@
#endif #endif
#define LLAMA_FILE_VERSION 1 #define LLAMA_FILE_VERSION 1
#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex #define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
#ifdef __cplusplus #ifdef __cplusplus

Binary file not shown.