Merge branch 'master' into concedo

# Conflicts:
#	.github/workflows/build.yml
#	.gitignore
#	CMakeLists.txt
#	Makefile
This commit is contained in:
Concedo 2023-03-29 21:08:03 +08:00
commit 49c4c225b5
28 changed files with 1089 additions and 729 deletions

6
.gitignore vendored
View file

@ -5,6 +5,7 @@
.vscode/
.DS_Store
.build/
build/
build-em/
build-debug/
@ -20,6 +21,7 @@ models/*
/quantize
/result
/perplexity
/embedding
arm_neon.h
compile_commands.json
@ -27,5 +29,9 @@ compile_commands.json
.envrc
.direnv/
.venv
__pycache__
.swiftpm
dist/
llama_for_kobold.spec

View file

@ -35,6 +35,11 @@ CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
LDFLAGS =
# warnings
CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
#lets try enabling everything
CFLAGS += -pthread -mf16c -mfma -mavx2 -mavx -msse3
CXXFLAGS += -pthread

20
Package.swift Normal file
View file

@ -0,0 +1,20 @@
// swift-tools-version:5.3
import PackageDescription
let package = Package(
name: "llama",
products: [
.library(name: "llama", targets: ["llama"]),
],
targets: [
.target(
name: "llama",
path: ".",
sources: ["ggml.c", "llama.cpp"],
publicHeadersPath: "spm-headers",
cSettings: [.unsafeFlags(["-Wno-shorten-64-to-32"])]
),
],
cxxLanguageStandard: .cxx11
)

View file

@ -145,13 +145,11 @@ def main():
print(f"Extracting only the vocab from '{fname_model}'\n")
model = torch.load(fname_model, map_location="cpu")
with open(fname_out, "wb") as fout:
write_header(fout, hparams, ftype)
write_tokens(fout, tokenizer)
del model
print(f"Done. Output file: {fname_out}\n")

View file

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Original by https://github.com/eiz
# https://github.com/ggerganov/llama.cpp/issues/324#issuecomment-1476227818
import argparse
import glob
import os
import struct
import sys
from sentencepiece import SentencePieceProcessor
HPARAMS = keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
def parse_args():
parser = argparse.ArgumentParser(description='Upgrade old ggml model files to the current format')
parser.add_argument('dir_model', help='directory containing ggml .bin files')
parser.add_argument('tokenizer_model', help='path to LLaMA tokenizer.model file')
return parser.parse_args()
def read_header(f_in):
struct_fmt = "i" * (3 + len(HPARAMS))
struct_size = struct.calcsize(struct_fmt)
buf = f_in.read(struct_size)
return struct.unpack(struct_fmt, buf)
def write_header(f_out, header):
(magic, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype) = header
if magic != 0x67676d6c:
raise Exception('Invalid file magic. Must be an old style ggml file.')
values = [
0x67676d66, # magic: ggml in hex
1, # file version
vocab_size,
dim,
multiple_of,
n_heads,
n_layers,
rot,
ftype
]
f_out.write(struct.pack("i" * len(values), *values))
def write_tokens(fout, tokenizer):
for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
elif tokenizer.is_control(i):
text = b""
elif tokenizer.is_byte(i):
piece = tokenizer.id_to_piece(i)
if len(piece) != 6:
print(f"Invalid token: {piece}")
sys.exit(1)
byte_value = int(piece[3:-1], 16)
text = struct.pack("B", byte_value)
else:
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", tokenizer.get_score(i)))
def read_tokens(f_in, tokenizer):
for i in range(tokenizer.vocab_size()):
len_b = f_in.read(4)
(length,) = struct.unpack("i", len_b)
f_in.read(length)
def copy_all_data(f_out, f_in):
while True:
buf = f_in.read(1024 * 1024)
if not buf:
break
f_out.write(buf)
def convert_one_file(path_in, tokenizer):
path_tmp = f"{path_in}.tmp"
path_orig= f"{path_in}.orig"
print(f"converting {path_in}")
with open(path_in, "rb") as f_in, open(path_tmp, "wb") as f_out:
write_header(f_out, read_header(f_in))
read_tokens(f_in, tokenizer)
write_tokens(f_out, tokenizer)
copy_all_data(f_out, f_in)
os.rename(path_in, path_orig)
os.rename(path_tmp, path_in)
def main():
args = parse_args()
files = []
files.extend(glob.glob(f"{args.dir_model}/*.bin"))
files.extend(glob.glob(f"{args.dir_model}/*.bin.*"))
tokenizer = SentencePieceProcessor(args.tokenizer_model)
for file in files:
convert_one_file(file, tokenizer)
if __name__ == "__main__":
main()

294
convert_ggml_to_pth.py Normal file
View file

@ -0,0 +1,294 @@
# Author: github.com/ductai199x
import argparse
import os
import struct
import numpy as np
import torch
from numba import njit
from tqdm.auto import tqdm
def read_header(fin):
values = struct.unpack("i" * 9, fin.read(4 * 9))
_, _, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = values
return {
"vocab_size": vocab_size,
"dim": dim,
"multiple_of": multiple_of,
"n_heads": n_heads,
"n_layers": n_layers,
}, ftype
def read_tokens(fin, vocab_size):
tokens = []
for _ in range(vocab_size):
text_len = struct.unpack("i", fin.read(4))[0]
text_bytes = fin.read(text_len)
try:
text = text_bytes.decode("utf-8")
except UnicodeDecodeError:
text = text_bytes.decode("utf-8", "replace")
score = struct.unpack("f", fin.read(4))[0]
tokens.append((text, score))
return tokens
@njit
def dequantize_weights_numba(fin_data, n_rows, n_cols):
qk = 32
nb = n_cols // qk
bs = 4 + (qk // 2)
weights = np.zeros((n_rows, n_cols), dtype=np.float32)
data_pos = 0
for row in range(n_rows):
for block in range(nb):
d = np.frombuffer(fin_data[data_pos : data_pos + 4], dtype=np.float32)[0]
data_pos += 4
packed_values = fin_data[data_pos : data_pos + (qk // 2)]
data_pos += qk // 2
for i in range(qk // 2):
packed_value = packed_values[i]
v0 = np.float32((packed_value & 0b00001111) - 8) * d
v1 = np.float32((packed_value >> 4) - 8) * d
weights[row, block * qk + 2 * i] = v0
weights[row, block * qk + 2 * i + 1] = v1
return weights
def dequantize_weights(fin, n_rows, n_cols):
qk = 32
nb = n_cols // qk
data_size = n_rows * n_cols // 2 + n_rows * nb * 4
fin_data = fin.read(data_size)
return dequantize_weights_numba(fin_data, n_rows, n_cols)
def read_variables(fin):
model = {}
pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
while True:
start_pos = fin.tell()
try:
n_dims, name_length, ftype_cur = struct.unpack("iii", fin.read(4 * 3))
except struct.error:
break
shape = tuple(struct.unpack("i" * n_dims, fin.read(4 * n_dims)))
shape = shape[::-1]
name = fin.read(name_length).decode("utf-8")
if ftype_cur == 2:
# 4-bit quantized weights
dtype = np.uint8
data = dequantize_weights(fin, shape[0], shape[1])
data = data.reshape(shape)
elif ftype_cur == 0:
dtype = np.float32
data_size = np.prod(shape)
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
elif ftype_cur == 1:
dtype = np.float16
data_size = np.prod(shape)
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
pbar.update(fin.tell() - start_pos)
return model
def convert_to_hf_format(model, hparams):
# This works for llama 7B, need to test with other models
n_layers = hparams["n_layers"]
n_heads = hparams["n_heads"]
dim = hparams["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
# permute for sliced rotary
def permute(w):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
state_dict = {}
for layer_i in range(n_layers):
state_dict.update(
{
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
model[f"layers.{layer_i}.attention.wq.weight"]
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
model[f"layers.{layer_i}.attention.wk.weight"]
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": model[
f"layers.{layer_i}.attention.wv.weight"
],
f"model.layers.{layer_i}.self_attn.o_proj.weight": model[
f"layers.{layer_i}.attention.wo.weight"
],
f"model.layers.{layer_i}.mlp.gate_proj.weight": model[
f"layers.{layer_i}.feed_forward.w1.weight"
],
f"model.layers.{layer_i}.mlp.down_proj.weight": model[
f"layers.{layer_i}.feed_forward.w2.weight"
],
f"model.layers.{layer_i}.mlp.up_proj.weight": model[
f"layers.{layer_i}.feed_forward.w3.weight"
],
f"model.layers.{layer_i}.input_layernorm.weight": model[
f"layers.{layer_i}.attention_norm.weight"
],
f"model.layers.{layer_i}.post_attention_layernorm.weight": model[
f"layers.{layer_i}.ffn_norm.weight"
],
}
)
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
state_dict.update(
{
"model.embed_tokens.weight": model["tok_embeddings.weight"],
"model.norm.weight": model["norm.weight"],
"lm_head.weight": model["output.weight"],
}
)
return state_dict
def chat(model, hparams, llama_dir):
from transformers import (GenerationConfig, LlamaForCausalLM,
LlamaTokenizer, StoppingCriteria,
StoppingCriteriaList)
from transformers.models.llama.configuration_llama import LlamaConfig
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self):
super().__init__()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
print(tokenizer.decode(input_ids[0]), end="", flush=True)
if input_ids[0][-1] == 13:
return True
return False
config = LlamaConfig(
vocab_size=hparams["vocab_size"],
dim=hparams["dim"],
num_hidden_layers=hparams["n_layers"],
num_attention_heads=hparams["n_heads"],
)
llama = LlamaForCausalLM(config=config)
llama.load_state_dict(state_dict=model, strict=True)
tokenizer = LlamaTokenizer.from_pretrained(llama_dir)
device = torch.device("cpu")
llama = llama.to(device)
ctx = """You are AI.
This is a dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, respectful, direct, concise, should try to protect User's privacy, and knows its own limits. Also, AI must answer User and AI cannot stop the conversation by itself.
User: Hello, AI.
AI: Hello! How can I assist you today?
"""
print(ctx.rstrip("\n"))
while True:
print("-" * 60)
prompt = input(f"User: ")
if ctx != "":
ctx = ctx + "User: " + prompt + "\n"
else:
ctx = prompt + "\nAI:"
ctx = (ctx[-1920:]) if len(ctx) >= 2048 else ctx
print("-" * 60)
if len(ctx.strip()) > 0:
input_ids = tokenizer(ctx, return_tensors="pt")["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=0.8,
top_p=0.95,
top_k=50,
repetition_penalty=1.1764,
)
with torch.no_grad():
generation_output = llama.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_length=2048,
do_sample=True,
stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub()]),
)
s = generation_output.sequences[0]
decoded = tokenizer.decode(s)
ctx = decoded + "\n"
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
)
parser.add_argument(
"--prefix",
"-p",
type=str,
required=True,
help="The prefix of the ggml files (ggml-model-f16 or ggml-model-q4_0).",
)
parser.add_argument(
"--hf",
action="store_true",
help="Whether to save the model in the huggingface format. (default: False)",
)
parser.add_argument(
"--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
)
args = parser.parse_args()
llama_dir = os.path.abspath(f"{args.input_dir}/../")
ggml_files = sorted(
[f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
)
fin = open(ggml_files[0], "rb")
hparams, ftype = read_header(fin)
tokens = read_tokens(fin, hparams["vocab_size"])
model = read_variables(fin)
for f in tqdm(ggml_files[1:]):
fin = open(f, "rb")
read_header(fin)
read_tokens(fin, hparams["vocab_size"])
model.update(read_variables(fin))
if args.hf:
model = convert_to_hf_format(model, hparams)
pth_ckpt = {
"state_dict": model,
"hparams": hparams,
"tokens": tokens,
}
torch.save(pth_ckpt, f"{args.input_dir}/{args.prefix}-to-torch.pth")
if args.chat:
if not args.hf:
model = convert_to_hf_format(model, hparams)
chat(model, hparams, llama_dir)
if __name__ == "__main__":
main()

View file

@ -9,11 +9,20 @@
#include <iterator>
#include <algorithm>
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif
#if defined (_WIN32)
#pragma comment(lib,"kernel32.lib")
extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
#endif
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
// determine sensible default number of threads.
@ -204,19 +213,19 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 - infinity)\n", params.n_predict);
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
if (ggml_mlock_supported()) {
fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
@ -256,3 +265,47 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
/* Keep track of current color of output, and emit ANSI code if it changes. */
void set_console_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {
switch(color) {
case CONSOLE_COLOR_DEFAULT:
printf(ANSI_COLOR_RESET);
break;
case CONSOLE_COLOR_PROMPT:
printf(ANSI_COLOR_YELLOW);
break;
case CONSOLE_COLOR_USER_INPUT:
printf(ANSI_BOLD ANSI_COLOR_GREEN);
break;
}
con_st.color = color;
}
}
#if defined (_WIN32)
void win32_console_init(bool enable_color) {
unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
hConOut = 0;
}
}
if (hConOut) {
// Enable ANSI colors on Windows 10+
if (enable_color && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
}
// Set console output codepage to UTF8
SetConsoleOutputCP(65001); // CP_UTF8
}
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF8
SetConsoleCP(65001); // CP_UTF8
}
}
#endif

View file

@ -63,3 +63,33 @@ std::string gpt_random_prompt(std::mt19937 & rng);
//
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
//
// Console utils
//
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"
enum console_color_t {
CONSOLE_COLOR_DEFAULT=0,
CONSOLE_COLOR_PROMPT,
CONSOLE_COLOR_USER_INPUT
};
struct console_state {
bool use_color = false;
console_color_t color = CONSOLE_COLOR_DEFAULT;
};
void set_console_color(console_state & con_st, console_color_t color);
#if defined (_WIN32)
void win32_console_init(bool enable_color);
#endif

View file

@ -1,4 +1,4 @@
set(TARGET embedding)
add_executable(${TARGET} embedding.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ggml ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -1,4 +1,4 @@
set(TARGET main)
add_executable(${TARGET} main.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ggml ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -18,58 +18,13 @@
#include <signal.h>
#endif
#if defined (_WIN32)
#pragma comment(lib,"kernel32.lib")
extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
#endif
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"
/* Keep track of current color of output, and emit ANSI code if it changes. */
enum console_state {
CONSOLE_STATE_DEFAULT=0,
CONSOLE_STATE_PROMPT,
CONSOLE_STATE_USER_INPUT
};
static console_state con_st = CONSOLE_STATE_DEFAULT;
static bool con_use_color = false;
void set_console_state(console_state new_st) {
if (!con_use_color) return;
// only emit color code if state changed
if (new_st != con_st) {
con_st = new_st;
switch(con_st) {
case CONSOLE_STATE_DEFAULT:
printf(ANSI_COLOR_RESET);
return;
case CONSOLE_STATE_PROMPT:
printf(ANSI_COLOR_YELLOW);
return;
case CONSOLE_STATE_USER_INPUT:
printf(ANSI_BOLD ANSI_COLOR_GREEN);
return;
}
}
}
static console_state con_st;
static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) {
if (!is_interacting) {
@ -81,32 +36,6 @@ void sigint_handler(int signo) {
}
#endif
#if defined (_WIN32)
void win32_console_init(void) {
unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
hConOut = 0;
}
}
if (hConOut) {
// Enable ANSI colors on Windows 10+
if (con_use_color && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
}
// Set console output codepage to UTF8
SetConsoleOutputCP(65001); // CP_UTF8
}
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF8
SetConsoleCP(65001); // CP_UTF8
}
}
#endif
int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";
@ -115,13 +44,12 @@ int main(int argc, char ** argv) {
return 1;
}
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
con_use_color = params.use_color;
con_st.use_color = params.use_color;
#if defined (_WIN32)
win32_console_init();
win32_console_init(params.use_color);
#endif
if (params.perplexity) {
@ -218,7 +146,10 @@ int main(int argc, char ** argv) {
return 1;
}
params.n_keep = std::min(params.n_keep, (int) embd_inp.size());
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
params.n_keep = (int)embd_inp.size();
}
// prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
@ -226,16 +157,12 @@ int main(int argc, char ** argv) {
// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
params.interactive = true;
params.interactive_start = true;
params.antiprompt.push_back("### Instruction:\n\n");
}
// enable interactive mode if reverse prompt is specified
if (params.antiprompt.size() != 0) {
params.interactive = true;
}
if (params.interactive_start) {
// enable interactive mode if reverse prompt or interactive start is specified
if (params.antiprompt.size() != 0 || params.interactive_start) {
params.interactive = true;
}
@ -282,7 +209,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");
@ -297,17 +225,18 @@ int main(int argc, char ** argv) {
#endif
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n");
is_interacting = params.interactive_start || params.instruct;
is_interacting = params.interactive_start;
}
bool input_noecho = false;
bool is_antiprompt = false;
bool input_noecho = false;
int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
@ -346,10 +275,10 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
llama_token id = 0;
@ -408,36 +337,38 @@ int main(int argc, char ** argv) {
}
// reset color to default if we there is no pending user input
if (!input_noecho && (int)embd_inp.size() == n_consumed) {
set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
// check for reverse prompt
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
set_console_state(CONSOLE_STATE_USER_INPUT);
fflush(stdout);
break;
// check for reverse prompt
if (params.antiprompt.size()) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
is_antiprompt = true;
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
fflush(stdout);
break;
}
}
}
if (n_past > 0 && is_interacting) {
// potentially set color to indicate we are taking user input
set_console_state(CONSOLE_STATE_USER_INPUT);
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
if (params.instruct) {
n_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
printf("\n> ");
}
@ -463,17 +394,29 @@ int main(int argc, char ** argv) {
} while (another_line);
// done taking input, reset color
set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
auto line_inp = ::llama_tokenize(ctx, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
// Add tokens to embd only if the input buffer is non-empty
// Entering a empty line lets the user pass control back
if (buffer.length() > 1) {
if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
// instruct mode: insert instruction prefix
if (params.instruct && !is_antiprompt) {
n_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
}
auto line_inp = ::llama_tokenize(ctx, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
// instruct mode: insert response suffix
if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}
n_remain -= line_inp.size();
}
n_remain -= line_inp.size();
input_noecho = true; // do not echo this again
}
@ -506,7 +449,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
return 0;
}

View file

@ -1,4 +1,4 @@
set(TARGET perplexity)
add_executable(${TARGET} perplexity.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ggml ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -1,15 +1,17 @@
#include "common.h"
#include "llama.h"
std::vector<double> softmax(const std::vector<float>& logits) {
std::vector<double> probs(logits.size());
#include <cmath>
std::vector<float> softmax(const std::vector<float>& logits) {
std::vector<float> probs(logits.size());
float max_logit = logits[0];
for (float v : logits) max_logit = std::max(max_logit, v);
double sum_exp = 0.0;
for (size_t i = 0; i < logits.size(); i++) {
// Subtract the maximum logit value from the current logit value for numerical stability
float logit = logits[i] - max_logit;
double exp_logit = std::exp(logit);
const float logit = logits[i] - max_logit;
const float exp_logit = expf(logit);
sum_exp += exp_logit;
probs[i] = exp_logit;
}
@ -24,14 +26,16 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
auto tokens = ::llama_tokenize(ctx, params.prompt, true);
int count = 0;
double nll = 0.0;
int seq_count = tokens.size() / params.n_ctx;
double nll = 0.0;
fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
for (int i = 0; i < seq_count; ++i) {
int start = i * params.n_ctx;
int end = start + params.n_ctx - 1;
int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
// it is better to always be power of 2 for better performance
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
auto start_t = std::chrono::high_resolution_clock::now();
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
@ -40,7 +44,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
}
auto end_t = std::chrono::high_resolution_clock::now();
if (i == 0) {
double seconds = std::chrono::duration<double>(end_t - start_t).count();
const float seconds = std::chrono::duration<float>(end_t - start_t).count();
printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
}
// We get the logits for all the tokens in the context window (params.n_ctx)
@ -63,7 +67,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
std::vector<float> tok_logits(
logits + j * n_vocab,
logits + (j + 1) * n_vocab);
double prob = softmax(tok_logits)[tokens[start + j + 1]];
const float prob = softmax(tok_logits)[tokens[start + j + 1]];
nll += -std::log(prob);
++count;
}

View file

@ -1,4 +1,4 @@
set(TARGET quantize)
add_executable(${TARGET} quantize.cpp)
target_link_libraries(${TARGET} PRIVATE llama ggml ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -4,8 +4,6 @@
#include <cstdio>
#include <string>
const int QK = 32;
// usage:
// ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
//
@ -39,7 +37,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();
if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype, QK)) {
if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1;
}
@ -52,8 +50,8 @@ int main(int argc, char ** argv) {
const int64_t t_main_end_us = ggml_time_us();
printf("\n");
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
}
return 0;

928
ggml.c

File diff suppressed because it is too large Load diff

4
ggml.h
View file

@ -748,8 +748,8 @@ enum ggml_opt_result ggml_opt(
// quantization
//
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
//
// system info

View file

@ -321,7 +321,7 @@ static bool llama_model_load(
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files or convert them with convert-unversioned-ggml-to-ggml.py!)\n",
__func__, fname.c_str());
legacy_file_format = true;
}
@ -786,8 +786,8 @@ static bool llama_model_load(
// progress
if (progress_callback) {
double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
double current_progress = (double(i) + current_file_progress) / double(n_parts);
float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
float current_progress = (float(i) + current_file_progress) / float(n_parts);
progress_callback(current_progress, progress_callback_user_data);
}
if (model.n_loaded % 8 == 0) {
@ -929,7 +929,7 @@ static bool llama_eval_internal(
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
@ -1247,12 +1247,12 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
// sampling
//
static void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
// find the top k tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
[](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
return a.first > b.first;
});
@ -1263,9 +1263,9 @@ static llama_vocab::id llama_sample_top_p_top_k(
llama_context & lctx,
const std::vector<llama_vocab::id> & last_n_tokens,
int top_k,
double top_p,
double temp,
double repeat_penalty) {
float top_p,
float temp,
float repeat_penalty) {
auto & rng = lctx.rng;
const int n_logits = lctx.model.hparams.n_vocab;
@ -1273,17 +1273,17 @@ static llama_vocab::id llama_sample_top_p_top_k(
const auto & logits = lctx.logits;
const auto * plogits = logits.data() + logits.size() - n_logits;
std::vector<std::pair<double, llama_vocab::id>> logits_id;
std::vector<std::pair<float, llama_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const double scale = 1.0/temp;
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0) {
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
@ -1296,18 +1296,18 @@ static llama_vocab::id llama_sample_top_p_top_k(
sample_top_k(logits_id, top_k);
double maxl = -std::numeric_limits<double>::infinity();
float maxl = -std::numeric_limits<float>::infinity();
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top k tokens
std::vector<double> probs;
std::vector<float> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
const float p = expf(kv.first - maxl);
probs.push_back(p);
sum += p;
}
@ -1317,8 +1317,8 @@ static llama_vocab::id llama_sample_top_p_top_k(
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
if (top_p < 1.0) {
double cumsum = 0.0;
for (int i = 0; i < (int) probs.size(); i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
@ -1352,7 +1352,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
//
// TODO: reuse code from the llama_model_load() somehow
bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype, int qk) {
static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
ggml_type type = GGML_TYPE_Q4_1;
switch (itype) {
@ -1575,11 +1575,11 @@ bool llama_model_quantize_internal(const std::string & fname_inp, const std::str
switch (type) {
case GGML_TYPE_Q4_0:
{
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_1:
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
default:
{
@ -1597,7 +1597,7 @@ bool llama_model_quantize_internal(const std::string & fname_inp, const std::str
}
for (int i = 0; i < (int) hist_cur.size(); ++i) {
printf("%5.3f ", hist_cur[i] / (float)nelements);
printf("%5.3f ", hist_cur[i] / float(nelements));
}
printf("\n");
} else {
@ -1620,7 +1620,7 @@ bool llama_model_quantize_internal(const std::string & fname_inp, const std::str
printf("%s: hist: ", __func__);
for (int i = 0; i < (int) hist_all.size(); ++i) {
printf("%5.3f ", hist_all[i] / (float)sum_all);
printf("%5.3f ", hist_all[i] / float(sum_all));
}
printf("\n");
}
@ -1718,9 +1718,8 @@ void llama_free(struct llama_context * ctx) {
int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
int itype,
int qk) {
if (!llama_model_quantize_internal(fname_inp, fname_out, itype, qk)) {
int itype) {
if (!llama_model_quantize_internal(fname_inp, fname_out, itype)) {
fprintf(stderr, "%s: failed to quantize\n", __func__);
return 1;
}
@ -1803,9 +1802,9 @@ llama_token llama_sample_top_p_top_k(
const llama_token * last_n_tokens_data,
int last_n_tokens_size,
int top_k,
double top_p,
double temp,
double repeat_penalty) {
float top_p,
float temp,
float repeat_penalty) {
const int64_t t_start_sample_us = ggml_time_us();
llama_token result = 0;
@ -1836,11 +1835,11 @@ void llama_print_timings(struct llama_context * ctx) {
const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
fprintf(stderr, "\n");
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_sample_us, n_sample, 1e-3f * ctx->t_sample_us / n_sample);
fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3f * ctx->t_p_eval_us, n_p_eval, 1e-3f * ctx->t_p_eval_us / n_p_eval);
fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_eval_us, n_eval, 1e-3f * ctx->t_eval_us / n_eval);
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample);
fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval);
fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval);
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0);
}
void llama_reset_timings(struct llama_context * ctx) {

13
llama.h
View file

@ -6,7 +6,7 @@
#include <stdbool.h>
#ifdef LLAMA_SHARED
# ifdef _WIN32
# ifdef _WIN32 && !defined __MINGW32__
# ifdef LLAMA_BUILD
# define LLAMA_API __declspec(dllexport)
# else
@ -45,7 +45,7 @@ extern "C" {
} llama_token_data;
typedef void (*llama_progress_callback)(double progress, void *ctx);
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
@ -81,8 +81,7 @@ extern "C" {
LLAMA_API int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
int itype,
int qk);
int itype);
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
@ -135,9 +134,9 @@ extern "C" {
const llama_token * last_n_tokens_data,
int last_n_tokens_size,
int top_k,
double top_p,
double temp,
double repeat_penalty);
float top_p,
float temp,
float repeat_penalty);
// Performance information
LLAMA_API void llama_print_timings(struct llama_context * ctx);

Binary file not shown.

BIN
main.exe

Binary file not shown.

Binary file not shown.

View file

@ -74,6 +74,10 @@ def main():
args.models_path, model, "ggml-model-f16.bin"
)
if not os.path.isfile(f16_model_path_base):
print(f'The file %s was not found' % f16_model_path_base)
sys.exit(1)
f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*")

1
spm-headers/llama.h Symbolic link
View file

@ -0,0 +1 @@
../llama.h

View file

@ -5,5 +5,6 @@ function(llama_add_test source)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
endfunction()
# llama_add_test(test-double-float.c) # SLOW
llama_add_test(test-quantize.c)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)

53
tests/test-double-float.c Normal file
View file

@ -0,0 +1,53 @@
// These tests may take a long time!
// They are to prove that conversion from double to float of various functions in ggml.c doesn't affect the result.
// This is done by checking all finite (non-NaN, non-infinite) floats.
#undef NDEBUG
#include <assert.h>
#include <immintrin.h>
#include <math.h>
#include <stdint.h>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdouble-promotion"
// ggml.c::quantize_row_q4_0_reference
inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; }
// ggml.c::ggml_silu_f32
inline static float silu_orig(float x) {
return x/(1.0 + exp(-x));
}
#pragma GCC diagnostic pop
// ggml.c::quantize_row_q4_0_reference
inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; }
// ggml.c::ggml_silu_f32
inline static float silu_float(float x) {
return x/(1.0f + expf(-x));
}
int main(void) {
uint32_t x = UINT32_MAX;
do {
float f = *(float *)&x;
assert(!isfinite(f) || (round_orig(f) == round_float(f)));
} while (x--);
#ifdef __F16C__
// GELU and SILU implementations are used with a FP16 lookup table.
// The original and float-only results are not equal for all inputs after converting to FP16.
// GELU is an approximation anyway (tanh), not tested here.
// For SILU, verify that the results are at least the closest floating point numbers, if the FP16 values don't match.
for (x = 0; x <= UINT16_MAX; x++) {
float f = _cvtsh_ss(x);
const float so = silu_orig(f);
const float sf = silu_float(f);
assert( (_cvtss_sh(so, 0) == _cvtss_sh(sf, 0))
|| (nextafterf(so, sf) == sf)
|| (nextafterf(sf, so) == so));
}
#endif
}

View file

@ -13,7 +13,7 @@ int main(void) {
src[i] = (float)(i + 1);
}
size_t size = ggml_quantize_q4_0(src, dst, QK, QK, QK, hist);
size_t size = ggml_quantize_q4_0(src, dst, QK, QK, hist);
assert(size == 20);
float max_result = ((float *)dst)[0];
float max_expected = src[31] / ((1 << 3) - 1);
@ -24,7 +24,7 @@ int main(void) {
assert(q4_result == q4_expected);
}
size = ggml_quantize_q4_1(src, dst, QK, QK, QK, hist);
size = ggml_quantize_q4_1(src, dst, QK, QK, hist);
assert(size == 24);
float delta_result = ((float *)dst)[0];
float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1);

View file

@ -77,5 +77,7 @@ int main(int argc, char **argv) {
}
}
llama_free(ctx);
return 0;
}