Revert "Make loading weights 10-100x faster"
This reverts commit 78ca9838ee
.
This commit is contained in:
parent
255f019f88
commit
17b98ca9ff
7 changed files with 370 additions and 331 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -22,7 +22,6 @@ models/*
|
|||
/result
|
||||
/perplexity
|
||||
/embedding
|
||||
/Pipfile
|
||||
|
||||
arm_neon.h
|
||||
compile_commands.json
|
||||
|
|
|
@ -84,11 +84,6 @@ def read_variables(fin):
|
|||
shape = shape[::-1]
|
||||
name = fin.read(name_length).decode()
|
||||
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fin.tell()
|
||||
tensor_data_offset = (tensor_data_offset + 31) & -32
|
||||
fin.seek(tensor_data_offset)
|
||||
|
||||
if ftype_cur == 2:
|
||||
# 4-bit quantized weights
|
||||
dtype = np.uint8
|
||||
|
|
|
@ -72,11 +72,6 @@ def write_header(shape, dst_name, ftype_cur):
|
|||
fout.write(struct.pack("i" * len(shape), *shape[::-1]))
|
||||
fout.write(sname)
|
||||
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fout.tell()
|
||||
tensor_data_offset = (tensor_data_offset + 31) & -32
|
||||
fout.seek(tensor_data_offset)
|
||||
|
||||
def convert_non_q4(src_name, dst_name):
|
||||
v = model[src_name]
|
||||
shape = v.shape
|
||||
|
|
|
@ -24,57 +24,8 @@ import torch
|
|||
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
QK = 32
|
||||
|
||||
GGML_TYPE_Q4_0 = 0
|
||||
GGML_TYPE_Q4_1 = 1
|
||||
GGML_TYPE_I8 = 2
|
||||
GGML_TYPE_I16 = 3
|
||||
GGML_TYPE_I32 = 4
|
||||
GGML_TYPE_F16 = 5
|
||||
GGML_TYPE_F32 = 6
|
||||
|
||||
WTYPES = {
|
||||
0: GGML_TYPE_F32,
|
||||
1: GGML_TYPE_F16,
|
||||
2: GGML_TYPE_Q4_0,
|
||||
3: GGML_TYPE_Q4_1,
|
||||
}
|
||||
|
||||
GGML_BLCK_SIZE = {
|
||||
GGML_TYPE_Q4_0: QK,
|
||||
GGML_TYPE_Q4_1: QK,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 1,
|
||||
GGML_TYPE_I32: 1,
|
||||
GGML_TYPE_F16: 1,
|
||||
GGML_TYPE_F32: 1,
|
||||
}
|
||||
|
||||
GGML_TYPE_SIZE = {
|
||||
GGML_TYPE_Q4_0: 4 + QK/2,
|
||||
GGML_TYPE_Q4_1: 4*2 + QK/2,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 2,
|
||||
GGML_TYPE_I32: 4,
|
||||
GGML_TYPE_F16: 2,
|
||||
GGML_TYPE_F32: 4,
|
||||
}
|
||||
|
||||
def ggml_nelements(shape):
|
||||
r = 1
|
||||
for i in shape:
|
||||
r *= i
|
||||
return r
|
||||
|
||||
def ggml_nbytes(shape, ftype):
|
||||
x = ggml_nelements(shape)
|
||||
t = WTYPES[ftype]
|
||||
x *= GGML_TYPE_SIZE[t]
|
||||
x //= GGML_BLCK_SIZE[t]
|
||||
return x
|
||||
|
||||
def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
|
||||
parser.add_argument('dir_model', help='directory containing the model checkpoint')
|
||||
parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
|
||||
|
@ -82,6 +33,7 @@ def parse_args():
|
|||
return parser.parse_args()
|
||||
|
||||
def get_n_parts(dim):
|
||||
|
||||
mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
|
||||
n_parts = mappings.get(dim)
|
||||
if n_parts is None:
|
||||
|
@ -92,24 +44,30 @@ def get_n_parts(dim):
|
|||
return n_parts
|
||||
|
||||
def load_hparams_and_tokenizer(dir_model):
|
||||
|
||||
# `dir_model` is something like `models/7B` or `models/7B/`.
|
||||
# "tokenizer.model" is expected under model's parent dir.
|
||||
# When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
|
||||
# Let's use the model's parent dir directly.
|
||||
model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
|
||||
|
||||
fname_hparams = f"{dir_model}/params.json"
|
||||
fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
|
||||
|
||||
with open(fname_hparams, "r") as f:
|
||||
hparams = json.load(f)
|
||||
print(hparams)
|
||||
|
||||
tokenizer = SentencePieceProcessor(fname_tokenizer)
|
||||
hparams.update({"vocab_size": tokenizer.vocab_size()})
|
||||
|
||||
return hparams, tokenizer
|
||||
|
||||
def write_header(fout, hparams, ftype):
|
||||
|
||||
keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
|
||||
values = [
|
||||
0x67676a74, # magic: ggjt in hex
|
||||
0x67676d66, # magic: ggmf in hex
|
||||
1, # file version
|
||||
*[hparams[key] for key in keys],
|
||||
hparams["dim"] // hparams["n_heads"], # rot (obsolete)
|
||||
|
@ -118,6 +76,7 @@ def write_header(fout, hparams, ftype):
|
|||
fout.write(struct.pack("i" * len(values), *values))
|
||||
|
||||
def write_tokens(fout, tokenizer):
|
||||
|
||||
for i in range(tokenizer.vocab_size()):
|
||||
if tokenizer.is_unknown(i):
|
||||
text = " \u2047 ".encode()
|
||||
|
@ -136,141 +95,85 @@ def write_tokens(fout, tokenizer):
|
|||
fout.write(text)
|
||||
fout.write(struct.pack("f", tokenizer.get_score(i)))
|
||||
|
||||
def process_and_write_variables(fout, model, ftype, part_id, n_parts):
|
||||
def process_and_write_variables(fout, model, ftype):
|
||||
|
||||
for name, datao in model.items():
|
||||
|
||||
if name.endswith("freqs"):
|
||||
continue
|
||||
|
||||
# remove dimensions with a single element
|
||||
shape = datao.shape
|
||||
|
||||
print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
|
||||
|
||||
data = datao.numpy().squeeze()
|
||||
partshape = data.shape
|
||||
n_dims = len(data.shape)
|
||||
assert n_dims in (1, 2)
|
||||
n_dims = len(shape)
|
||||
|
||||
print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}")
|
||||
|
||||
# coerce single-dimensional tensors from float16 to float32
|
||||
# default type is fp16
|
||||
ftype_cur = 1
|
||||
if ftype == 0 or n_dims == 1:
|
||||
print(" Converting to float32")
|
||||
data = data.astype(np.float32)
|
||||
ftype_cur = 0
|
||||
blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
|
||||
type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
|
||||
|
||||
# determine dimension along which multipart tensor is sharded
|
||||
#
|
||||
# split_dim 0 regex:
|
||||
# - output.*
|
||||
# - layers.*.attention.wq.weight
|
||||
# - layers.*.attention.wk.weight
|
||||
# - layers.*.attention.wv.weight
|
||||
# - layers.*.feed_forward.w1.weight
|
||||
# - layers.*.feed_forward.w3.weight
|
||||
#
|
||||
# split_dim 1 regex:
|
||||
# - tok_embeddings.*
|
||||
# - layers.*.attention.wo.weight
|
||||
# - layers.*.feed_forward.w2.weight
|
||||
#
|
||||
if n_dims > 1:
|
||||
split_dim = 1
|
||||
if "tok_embeddings" in name:
|
||||
split_dim = 1
|
||||
elif "layers" in name:
|
||||
if "attention.wo.weight" in name:
|
||||
split_dim = 1
|
||||
elif "feed_forward.w2.weight" in name:
|
||||
split_dim = 1
|
||||
else:
|
||||
split_dim = 0
|
||||
elif "output" in name:
|
||||
split_dim = 0
|
||||
|
||||
# output tensor header
|
||||
fullshape = list(partshape)
|
||||
if n_dims > 1:
|
||||
fullshape[split_dim] *= n_parts
|
||||
sname = name.encode()
|
||||
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
|
||||
for dim in reversed(fullshape):
|
||||
# header
|
||||
sname = name.encode('utf-8')
|
||||
fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
|
||||
for dim in reversed(data.shape):
|
||||
fout.write(struct.pack("i", dim))
|
||||
fout.write(sname)
|
||||
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fout.tell()
|
||||
while tensor_data_offset % QK != 0:
|
||||
fout.write(struct.pack("B", 0))
|
||||
tensor_data_offset += 1
|
||||
|
||||
# output unified mappable tensor data
|
||||
if n_dims == 1 or n_parts == 1:
|
||||
# copy tensor which we thankfully received in one piece
|
||||
if part_id == 0:
|
||||
data.tofile(fout)
|
||||
elif split_dim == 0:
|
||||
# reassemble multifile tensor containing some of the rows
|
||||
rows_per_chunk = partshape[0]
|
||||
current_row = part_id * rows_per_chunk
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset = current_row * bytes_per_row
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
data.tofile(fout)
|
||||
elif split_dim == 1:
|
||||
# reassemble multifile tensor containing some of the cols
|
||||
cols_per_chunk = partshape[1]
|
||||
current_col = part_id * cols_per_chunk
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset_current_col = current_col // blck_size * type_size
|
||||
for row in range(partshape[0]):
|
||||
offset_row = row * bytes_per_row
|
||||
offset = offset_row + offset_current_col
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
data[row].tofile(fout)
|
||||
|
||||
# advance file position to next tensor
|
||||
fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
|
||||
# data output to file
|
||||
data.tofile(fout)
|
||||
|
||||
def main():
|
||||
|
||||
args = parse_args()
|
||||
dir_model = args.dir_model
|
||||
ftype = args.ftype
|
||||
ftype_str = ["f32", "f16"]
|
||||
|
||||
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
||||
|
||||
print(args)
|
||||
|
||||
# if only writing vocab to file
|
||||
if args.vocab_only:
|
||||
|
||||
fname_model = f"{dir_model}/consolidated.00.pth"
|
||||
fname_out = f"{dir_model}/ggml-vocab.bin"
|
||||
|
||||
print(f"Extracting only the vocab from '{fname_model}'\n")
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
|
||||
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
del model
|
||||
|
||||
|
||||
print(f"Done. Output file: {fname_out}\n")
|
||||
|
||||
return
|
||||
|
||||
n_parts = get_n_parts(hparams["dim"])
|
||||
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
|
||||
|
||||
# we output a single file for ggml
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
offset_of_tensors = fout.tell()
|
||||
# the tensors we load could be split across multiple files
|
||||
for part_id in range(n_parts):
|
||||
fout.seek(offset_of_tensors)
|
||||
print(f"Processing part {part_id+1} of {n_parts}\n")
|
||||
fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
process_and_write_variables(fout, model, ftype, part_id, n_parts)
|
||||
del model
|
||||
for p in range(n_parts):
|
||||
|
||||
print(f"Done. Output file: {fname_out}\n")
|
||||
print(f"Processing part {p+1} of {n_parts}\n")
|
||||
|
||||
fname_model = f"{dir_model}/consolidated.0{p}.pth"
|
||||
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
|
||||
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
process_and_write_variables(fout, model, ftype)
|
||||
|
||||
del model
|
||||
|
||||
print(f"Done. Output file: {fname_out}, (part {p})\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
491
llama.cpp
491
llama.cpp
|
@ -12,19 +12,17 @@
|
|||
#include <cassert>
|
||||
#include <cstring>
|
||||
|
||||
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#include <Windows.h>
|
||||
#else
|
||||
#include <sys/types.h>
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
// mmap
|
||||
#if defined (__unix__) || defined (__APPLE__)
|
||||
# include <sys/mman.h>
|
||||
# include <fcntl.h>
|
||||
# include <unistd.h>
|
||||
#elif defined(_WIN32)
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# include <Windows.h>
|
||||
//#include <Memoryapi.h>
|
||||
#endif
|
||||
|
||||
#define Min(X, Y) ((Y) > (X) ? (X) : (Y))
|
||||
#define Max(X, Y) ((Y) < (X) ? (X) : (Y))
|
||||
|
||||
#define LLAMA_USE_SCRATCH
|
||||
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
||||
|
||||
|
@ -157,7 +155,7 @@ struct llama_model {
|
|||
|
||||
// model memory mapped file
|
||||
void * mm_addr = NULL;
|
||||
uint64_t mm_length = 0;
|
||||
size_t mm_length = 0;
|
||||
|
||||
// tensors
|
||||
int n_loaded;
|
||||
|
@ -182,7 +180,6 @@ struct llama_context {
|
|||
|
||||
int64_t t_load_us = 0;
|
||||
int64_t t_start_us = 0;
|
||||
bool has_evaluated_once = false;
|
||||
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_eval_us = 0;
|
||||
|
@ -224,7 +221,7 @@ struct llama_context {
|
|||
}
|
||||
|
||||
if (buf_last >= 0) {
|
||||
buf_max_size[buf_last] = Max(buf_max_size[buf_last], last_size);
|
||||
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
||||
}
|
||||
|
||||
buf_last = i;
|
||||
|
@ -307,57 +304,59 @@ struct llama_context_params llama_context_default_params() {
|
|||
// model loading
|
||||
//
|
||||
|
||||
static void *mmap_file(const char *fname, uint64_t *mm_length) {
|
||||
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
|
||||
HANDLE hFile = CreateFileA(fname,
|
||||
GENERIC_READ,
|
||||
FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
|
||||
NULL,
|
||||
OPEN_EXISTING,
|
||||
FILE_ATTRIBUTE_NORMAL | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED,
|
||||
NULL);
|
||||
if (hFile == INVALID_HANDLE_VALUE) return 0;
|
||||
static void mmap_file(const char* fname, void * &mm_addr, size_t &mm_length) {
|
||||
#if defined(MAP_FAILED)
|
||||
// POSIX
|
||||
int fd = open(fname, O_RDONLY);
|
||||
mm_length = lseek(fd, 0, SEEK_END);
|
||||
mm_addr = mmap(NULL, mm_length, PROT_READ, MAP_SHARED, fd, 0);
|
||||
close(fd);
|
||||
if (mm_addr == MAP_FAILED) {
|
||||
perror("mmap failed");
|
||||
mm_addr = NULL;
|
||||
mm_length = 0;
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
mm_addr = NULL;
|
||||
|
||||
HANDLE hFile = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
|
||||
if (hFile == INVALID_HANDLE_VALUE) {
|
||||
return;
|
||||
}
|
||||
|
||||
// not really necessary
|
||||
LARGE_INTEGER fileSize;
|
||||
fileSize.QuadPart = -1;
|
||||
GetFileSizeEx(hFile, &fileSize);
|
||||
int64_t length = fileSize.QuadPart;
|
||||
mm_length = fileSize;
|
||||
|
||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
CloseHandle(hFile);
|
||||
if (!hMapping) return 0;
|
||||
void *addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
|
||||
if (hMapping == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
mm_addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
CloseHandle(hMapping);
|
||||
if (!addr) return 0;
|
||||
#else
|
||||
int fd = open(fname, O_RDONLY);
|
||||
if (fd == -1) return 0;
|
||||
int64_t length = lseek(fd, 0, SEEK_END);
|
||||
void *addr = mmap(NULL, length, PROT_READ, MAP_SHARED, fd, 0);
|
||||
close(fd);
|
||||
if (addr == MAP_FAILED) return 0;
|
||||
mm_addr = NULL;
|
||||
mm_length = 0;
|
||||
(void)(fname); // suppress warnings
|
||||
#endif
|
||||
*mm_length = length;
|
||||
return addr;
|
||||
}
|
||||
|
||||
static void munmap_file(void * addr, size_t length) {
|
||||
#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
|
||||
#if defined(MAP_FAILED)
|
||||
// POSIX
|
||||
munmap(addr, length);
|
||||
#elif defined(_WIN32)
|
||||
UnmapViewOfFile(addr);
|
||||
#else
|
||||
munmap(addr, length);
|
||||
(void)(addr); // suppress warnings
|
||||
(void)(length);
|
||||
#endif
|
||||
}
|
||||
|
||||
static bool report_bad_magic(const char *path) {
|
||||
fprintf(stderr,
|
||||
"%s: invalid model file (bad magic)\n"
|
||||
"you most likely need to regenerate your ggml files\n"
|
||||
"the benefit is you'll get 10-100x faster load times\n"
|
||||
"see https://github.com/ggerganov/llama.cpp/issues/91\n"
|
||||
"use convert-pth-to-ggml.py on your llama model files\n",
|
||||
path);
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool llama_model_load(
|
||||
const std::string & fname,
|
||||
llama_context & lctx,
|
||||
|
@ -369,24 +368,23 @@ static bool llama_model_load(
|
|||
void *progress_callback_user_data) {
|
||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||
|
||||
lctx.t_start_us = ggml_time_us();
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
lctx.t_start_us = t_start_us;
|
||||
|
||||
// TODO: this could probably be smaller when using mmap
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
|
||||
auto & model = lctx.model;
|
||||
auto & vocab = lctx.vocab;
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
||||
if (!fin) {
|
||||
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
||||
|
||||
fin.seekg(0, fin.end);
|
||||
const size_t file_size = fin.tellg();
|
||||
fin.seekg(0);
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
|
@ -397,7 +395,8 @@ static bool llama_model_load(
|
|||
return false;
|
||||
}
|
||||
if (magic != LLAMA_FILE_MAGIC) {
|
||||
return report_bad_magic(fname.c_str());
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t format_version;
|
||||
|
@ -520,24 +519,54 @@ static bool llama_model_load(
|
|||
}
|
||||
}
|
||||
|
||||
// map model into memory
|
||||
char *mm_addr = NULL;
|
||||
model.mm_addr = mmap_file(fname.c_str(), &model.mm_length);
|
||||
if (model.mm_addr == NULL) {
|
||||
fprintf(stderr, "%s: failed to mmap '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
bool use_mmap = (n_parts == 1);
|
||||
|
||||
// try to memory map the model file
|
||||
void * mm_addr = NULL;
|
||||
if (use_mmap) {
|
||||
mmap_file(fname.c_str(), model.mm_addr, model.mm_length);
|
||||
if (model.mm_addr == NULL) {
|
||||
use_mmap = false;
|
||||
}
|
||||
else {
|
||||
mm_addr = model.mm_addr;
|
||||
}
|
||||
}
|
||||
mm_addr = (char *)model.mm_addr;
|
||||
fprintf(stderr, "%s: ggml map size = %6.2f MB\n", __func__, model.mm_length/(1024.0*1024.0));
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
{
|
||||
const auto &hparams = model.hparams;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
if (!use_mmap) {
|
||||
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
|
||||
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
|
||||
|
||||
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
|
||||
|
||||
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
|
||||
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
|
||||
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
|
||||
}
|
||||
|
||||
ctx_size += (5 + 10*n_layer)*256; // object overhead
|
||||
fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
|
||||
|
||||
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
||||
}
|
||||
|
||||
// print memory requirements
|
||||
|
@ -547,7 +576,6 @@ static bool llama_model_load(
|
|||
// this is the total memory required to run the inference
|
||||
const size_t mem_required =
|
||||
ctx_size +
|
||||
model.mm_length +
|
||||
MEM_REQ_SCRATCH0.at(model.type) +
|
||||
MEM_REQ_SCRATCH1.at(model.type) +
|
||||
MEM_REQ_EVAL.at (model.type);
|
||||
|
@ -567,7 +595,7 @@ static bool llama_model_load(
|
|||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ lctx.model.buf.size(),
|
||||
/*.mem_buffer =*/ lctx.model.buf.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
/*.no_alloc =*/ use_mmap,
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
|
@ -630,106 +658,241 @@ static bool llama_model_load(
|
|||
}
|
||||
}
|
||||
|
||||
const size_t file_offset = fin.tellg();
|
||||
|
||||
fin.close();
|
||||
|
||||
std::vector<uint8_t> tmp;
|
||||
|
||||
if (progress_callback) {
|
||||
progress_callback(0.0, progress_callback_user_data);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: loading tensors from '%s'\n", __func__, fname.c_str());
|
||||
for (int i = 0; i < n_parts; ++i) {
|
||||
const int part_id = i;
|
||||
//const int part_id = n_parts - i - 1;
|
||||
|
||||
// load weights
|
||||
{
|
||||
size_t total_size = 0;
|
||||
model.n_loaded = 0;
|
||||
std::string fname_part = fname;
|
||||
if (i > 0) {
|
||||
fname_part += "." + std::to_string(i);
|
||||
}
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
fprintf(stderr, "%s: loading model part %d/%d from '%s'%s\n", __func__, i+1, n_parts, fname_part.c_str(), use_mmap ? " (memory mapped)" : "");
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
fin = std::ifstream(fname_part, std::ios::binary);
|
||||
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
fin.seekg(0, fin.end);
|
||||
const size_t file_size = fin.tellg();
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
fin.seekg(file_offset);
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
// load weights
|
||||
{
|
||||
size_t total_size = 0;
|
||||
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
model.n_loaded = 0;
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
fprintf(stderr, "%s: ", __func__);
|
||||
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%" PRId64 ", %" PRId64 "], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
if (0) {
|
||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
||||
fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
||||
}
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
switch (ftype) {
|
||||
case 0: // f32
|
||||
case 1: // f16
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
case 2: // q4_0
|
||||
case 3: // q4_1
|
||||
assert(ne[0] % 64 == 0);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
};
|
||||
}
|
||||
|
||||
// load the tensor data into memory without copying or reading it
|
||||
size_t offset = fin.tellg();
|
||||
size_t tensor_data_size = ggml_nbytes(tensor);
|
||||
offset = (offset + 31) & -32;
|
||||
tensor->data = mm_addr + offset;
|
||||
fin.seekg(offset + tensor_data_size);
|
||||
total_size += tensor_data_size;
|
||||
model.n_loaded++;
|
||||
// split_type = 0: split by columns
|
||||
// split_type = 1: split by rows
|
||||
int split_type = 0;
|
||||
|
||||
// progress
|
||||
if (progress_callback) {
|
||||
double current_progress = size_t(fin.tellg()) / double(file_size);
|
||||
progress_callback(current_progress, progress_callback_user_data);
|
||||
// split_type = 0:
|
||||
// regex:
|
||||
// - tok_embeddings.*
|
||||
// - layers.*.attention.wo.weight
|
||||
// - layers.*.feed_forward.w2.weight
|
||||
|
||||
// split_type = 1:
|
||||
// regex:
|
||||
// - output.*
|
||||
// - layers.*.attention.wq.weight
|
||||
// - layers.*.attention.wk.weight
|
||||
// - layers.*.attention.wv.weight
|
||||
// - layers.*.feed_forward.w1.weight
|
||||
// - layers.*.feed_forward.w3.weight
|
||||
if (name.find("tok_embeddings") != std::string::npos) {
|
||||
split_type = 0;
|
||||
} else if (name.find("layers") != std::string::npos) {
|
||||
if (name.find("attention.wo.weight") != std::string::npos) {
|
||||
split_type = 0;
|
||||
} else if (name.find("feed_forward.w2.weight") != std::string::npos) {
|
||||
split_type = 0;
|
||||
} else {
|
||||
split_type = 1;
|
||||
}
|
||||
} else if (name.find("output") != std::string::npos) {
|
||||
split_type = 1;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
|
||||
if (n_dims == 1) {
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (ggml_nelements(tensor)/n_parts != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_dims == 1) {
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (split_type == 0) {
|
||||
if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (0) {
|
||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
||||
fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
|
||||
}
|
||||
|
||||
size_t bpe = 0;
|
||||
|
||||
switch (ftype) {
|
||||
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
|
||||
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
|
||||
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
|
||||
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
if (n_dims == 1 || n_parts == 1) {
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (part_id == 0) {
|
||||
if (mm_addr) {
|
||||
off_t offset = fin.tellg();
|
||||
tensor->data = (char *) mm_addr + offset;
|
||||
fin.seekg(ggml_nbytes(tensor), std::ios::cur);
|
||||
}
|
||||
else {
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
}
|
||||
} else {
|
||||
fin.seekg(ggml_nbytes(tensor), std::ios::cur);
|
||||
}
|
||||
|
||||
total_size += ggml_nbytes(tensor);
|
||||
} else {
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (split_type == 0) {
|
||||
const int np0 = ne[0];
|
||||
|
||||
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
||||
assert(row_size == tensor->nb[1]);
|
||||
|
||||
for (int i1 = 0; i1 < ne[1]; ++i1) {
|
||||
const size_t offset_row = i1*row_size;
|
||||
const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
||||
fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
|
||||
}
|
||||
} else {
|
||||
const int np1 = ne[1];
|
||||
|
||||
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
||||
|
||||
for (int i1 = 0; i1 < ne[1]; ++i1) {
|
||||
const size_t offset_row = (i1 + part_id*np1)*row_size;
|
||||
fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
|
||||
}
|
||||
}
|
||||
|
||||
total_size += ggml_nbytes(tensor)/n_parts;
|
||||
}
|
||||
|
||||
//fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
model.n_loaded++;
|
||||
|
||||
// progress
|
||||
if (progress_callback) {
|
||||
float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
|
||||
float current_progress = (float(i) + current_file_progress) / float(n_parts);
|
||||
progress_callback(current_progress, progress_callback_user_data);
|
||||
}
|
||||
if (model.n_loaded % 8 == 0) {
|
||||
fprintf(stderr, ".");
|
||||
fflush(stderr);
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, " done\n");
|
||||
|
||||
fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
|
||||
if (model.n_loaded == 0) {
|
||||
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
||||
} else if (model.n_loaded != (int) model.tensors.size()) {
|
||||
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
fin.close();
|
||||
|
||||
fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
|
||||
if (model.n_loaded == 0) {
|
||||
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
||||
} else if (model.n_loaded != (int) model.tensors.size()) {
|
||||
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// loading time will be recalculate after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
|
||||
lctx.t_load_us = ggml_time_us() - t_start_us;
|
||||
|
||||
if (progress_callback) {
|
||||
progress_callback(1.0, progress_callback_user_data);
|
||||
|
@ -1053,7 +1216,7 @@ struct llama_tokenizer {
|
|||
size_t offs = 0;
|
||||
while (offs < text.size()) {
|
||||
llama_sp_symbol sym;
|
||||
size_t char_len = Min(text.size() - offs, utf8_len(text[offs]));
|
||||
size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
|
||||
sym.text = text.c_str() + offs;
|
||||
sym.n = char_len;
|
||||
offs += char_len;
|
||||
|
@ -1218,7 +1381,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
|||
|
||||
float maxl = -std::numeric_limits<float>::infinity();
|
||||
for (const auto & kv : logits_id) {
|
||||
maxl = Max(maxl, kv.first);
|
||||
maxl = std::max(maxl, kv.first);
|
||||
}
|
||||
|
||||
// compute probs for the top k tokens
|
||||
|
@ -1312,7 +1475,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
return false;
|
||||
}
|
||||
if (magic != LLAMA_FILE_MAGIC) {
|
||||
return report_bad_magic(fname_inp.c_str());
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
fout.write((char *) &magic, sizeof(magic));
|
||||
|
@ -1378,8 +1542,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
fout.write((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
finp.read ((char *) &word[0], len);
|
||||
fout.write((char *) &word[0], len);
|
||||
finp.read ((char *) word.data(), len);
|
||||
fout.write((char *) word.data(), len);
|
||||
|
||||
float score;
|
||||
finp.read ((char *) &score, sizeof(score));
|
||||
|
@ -1429,13 +1593,6 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
std::string name(length, 0);
|
||||
finp.read (&name[0], length);
|
||||
|
||||
{
|
||||
// ensure tensor data is aligned
|
||||
uint64_t offset = finp.tellg();
|
||||
offset = (offset + 31) & -32;
|
||||
finp.seekg(offset);
|
||||
}
|
||||
|
||||
{
|
||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
||||
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
||||
|
@ -1491,13 +1648,6 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
}
|
||||
fout.write(&name[0], length);
|
||||
|
||||
{
|
||||
// ensure tensor data is aligned
|
||||
uint64_t offset = fout.tellp();
|
||||
offset = (offset + 31) & -32;
|
||||
fout.seekp(offset);
|
||||
}
|
||||
|
||||
if (quantize) {
|
||||
printf("quantizing .. ");
|
||||
work.resize(nelements); // for quantization
|
||||
|
@ -1701,11 +1851,7 @@ int llama_eval(
|
|||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
// get a more accurate load time, upon first eval
|
||||
if (!ctx->has_evaluated_once) {
|
||||
ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
|
||||
ctx->has_evaluated_once = true;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -1798,9 +1944,9 @@ llama_token llama_sample_top_p_top_k(
|
|||
void llama_print_timings(struct llama_context * ctx) {
|
||||
const int64_t t_end_us = ggml_time_us();
|
||||
|
||||
const int32_t n_sample = Max(1, ctx->n_sample);
|
||||
const int32_t n_eval = Max(1, ctx->n_eval);
|
||||
const int32_t n_p_eval = Max(1, ctx->n_p_eval);
|
||||
const int32_t n_sample = std::max(1, ctx->n_sample);
|
||||
const int32_t n_eval = std::max(1, ctx->n_eval);
|
||||
const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
|
||||
|
@ -1812,6 +1958,7 @@ void llama_print_timings(struct llama_context * ctx) {
|
|||
|
||||
void llama_reset_timings(struct llama_context * ctx) {
|
||||
ctx->t_start_us = ggml_time_us();
|
||||
|
||||
ctx->t_sample_us = ctx->n_sample = 0;
|
||||
ctx->t_eval_us = ctx->n_eval = 0;
|
||||
ctx->t_p_eval_us = ctx->n_p_eval = 0;
|
||||
|
|
2
llama.h
2
llama.h
|
@ -20,7 +20,7 @@
|
|||
#endif
|
||||
|
||||
#define LLAMA_FILE_VERSION 1
|
||||
#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
|
||||
#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
|
||||
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
Binary file not shown.
Loading…
Add table
Add a link
Reference in a new issue