trace logits to a file
This commit is contained in:
parent
34ab526843
commit
840645dea7
8 changed files with 212 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -4,6 +4,7 @@
|
|||
.vs/
|
||||
.vscode/
|
||||
.DS_Store
|
||||
__pycache__
|
||||
|
||||
build/
|
||||
build-em/
|
||||
|
|
|
@ -169,6 +169,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.input_prefix = argv[i];
|
||||
} else if (arg == "--trace") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.trace_fn = argv[i];
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
gpt_print_usage(argc, argv, params);
|
||||
|
@ -224,6 +230,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||
fprintf(stderr, " --trace FNAME save the the model logits during evaluation to a binary file\n");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
@ -256,3 +263,43 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
|
|||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Open the trace file and write the header in the binary format: magic:int version:int n_vocab:int
|
||||
std::ofstream trace_open(const gpt_params & params, struct llama_context * ctx) {
|
||||
std::ofstream trace_ofs;
|
||||
|
||||
const uint32_t n_vocab = llama_n_vocab(ctx);
|
||||
if(n_vocab <= 0) {
|
||||
return trace_ofs;
|
||||
}
|
||||
const auto& trace_fn = params.trace_fn;
|
||||
trace_ofs.open(trace_fn, std::ios::binary);
|
||||
if(trace_ofs.is_open() && trace_ofs.good()) {
|
||||
fprintf(stderr, "Tracing evaluation to: '%s'\n", trace_fn.c_str());
|
||||
trace_ofs.write(reinterpret_cast<const char*>(&LLAMA_TRACE_MAGIC), sizeof(uint32_t));
|
||||
trace_ofs.write(reinterpret_cast<const char*>(&LLAMA_TRACE_VERSION), sizeof(uint32_t));
|
||||
trace_ofs.write(reinterpret_cast<const char*>(&n_vocab), sizeof(uint32_t));
|
||||
} else {
|
||||
fprintf(stderr, "Could not open trace file: '%s'\n", trace_fn.c_str());
|
||||
trace_ofs.close();
|
||||
}
|
||||
return trace_ofs;
|
||||
}
|
||||
|
||||
// Write a record using the binary format: N:int {N}token_id:int {N*n_vocab}logits:float
|
||||
void trace_write_record(
|
||||
std::ofstream & out,
|
||||
const std::vector<llama_token> & embd,
|
||||
struct llama_context * ctx) {
|
||||
|
||||
const uint32_t N = embd.size();
|
||||
const int n_vocab = llama_n_vocab(ctx);
|
||||
const float * logits = llama_get_logits(ctx);
|
||||
if(!out.is_open() || out.bad() || N == 0 || n_vocab <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.write(reinterpret_cast<const char*>(&N), sizeof(uint32_t));
|
||||
out.write(reinterpret_cast<const char*>(embd.data()), sizeof(llama_token)*N);
|
||||
out.write(reinterpret_cast<const char*>(logits), sizeof(float)*N*n_vocab);
|
||||
}
|
|
@ -32,6 +32,7 @@ struct gpt_params {
|
|||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||
std::string prompt = "";
|
||||
std::string input_prefix = ""; // string to prefix user inputs with
|
||||
std::string trace_fn = "";
|
||||
|
||||
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
|
@ -63,3 +64,19 @@ std::string gpt_random_prompt(std::mt19937 & rng);
|
|||
//
|
||||
|
||||
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
|
||||
|
||||
//
|
||||
// Trace utils
|
||||
//
|
||||
|
||||
static constexpr uint32_t LLAMA_TRACE_VERSION = 0;
|
||||
static constexpr uint32_t LLAMA_TRACE_MAGIC = 0x67676d74; // 'ggmt' in hex
|
||||
|
||||
// Open format: magic:int version:int n_vocab:int
|
||||
std::ofstream trace_open(const gpt_params & params, struct llama_context * ctx);
|
||||
|
||||
// Write a record using the binary format: N:int {N}token_id:int {N*n_vocab}logits:float
|
||||
void trace_write_record(
|
||||
std::ofstream & out,
|
||||
const std::vector<llama_token> & embd,
|
||||
struct llama_context * ctx);
|
||||
|
|
|
@ -169,6 +169,7 @@ int main(int argc, char ** argv) {
|
|||
lparams.n_parts = params.n_parts;
|
||||
lparams.seed = params.seed;
|
||||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.logits_all = !params.trace_fn.empty();
|
||||
lparams.use_mlock = params.use_mlock;
|
||||
|
||||
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
||||
|
@ -205,6 +206,8 @@ int main(int argc, char ** argv) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
std::ofstream trace_ofs = trace_open(params, ctx);
|
||||
|
||||
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||
params.prompt.insert(0, 1, ' ');
|
||||
|
||||
|
@ -339,6 +342,7 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
trace_write_record(trace_ofs, embd, ctx);
|
||||
}
|
||||
|
||||
n_past += embd.size();
|
||||
|
@ -502,6 +506,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_print_timings(ctx);
|
||||
llama_free(ctx);
|
||||
trace_ofs.close();
|
||||
|
||||
set_console_state(CONSOLE_STATE_DEFAULT);
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
std::vector<double> softmax(const std::vector<float>& logits) {
|
||||
std::vector<double> probs(logits.size());
|
||||
float max_logit = logits[0];
|
||||
|
@ -27,6 +29,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
|||
double nll = 0.0;
|
||||
int seq_count = tokens.size() / params.n_ctx;
|
||||
|
||||
std::ofstream trace_ofs = trace_open(params, ctx);
|
||||
fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
|
||||
|
||||
for (int i = 0; i < seq_count; ++i) {
|
||||
|
@ -57,6 +60,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
|||
// process the entire prompt.
|
||||
|
||||
auto logits = llama_get_logits(ctx);
|
||||
trace_write_record(trace_ofs, embd, ctx);
|
||||
|
||||
for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
|
||||
// Calculate probability of next token, given the previous ones.
|
||||
int n_vocab = llama_n_vocab(ctx);
|
||||
|
@ -72,6 +77,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
|||
fflush(stdout);
|
||||
}
|
||||
printf("\n");
|
||||
trace_ofs.close();
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
|
|
1
examples/traceparser/__init__.py
Normal file
1
examples/traceparser/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .parser import open_trace
|
70
examples/traceparser/__main__.py
Normal file
70
examples/traceparser/__main__.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
from . import open_trace
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Upgrade old ggml model files to the current format')
|
||||
parser.add_argument('trace_file', help='tracefile to read')
|
||||
parser.add_argument('--tokenizer', help='path to LLaMA tokenizer.model file',
|
||||
dest='tokenizer_model_file', default='models/tokenizer.model')
|
||||
parser.add_argument('--temp', help='Sampling temperature',
|
||||
dest='temperature', default=0.8, type=float)
|
||||
parser.add_argument('--top_k', help='top k tokens to sample', type=int)
|
||||
parser.add_argument('--top_p', help='nucleus probability', type=float, default=1.0)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def top_k_indices(logits, k):
|
||||
idxs = np.argpartition(logits, -k)[-k:]
|
||||
idxs = idxs[np.argsort(logits[idxs])][::-1]
|
||||
return idxs
|
||||
|
||||
def process_logits(logits, temp):
|
||||
logits = logits / temp
|
||||
logp = logits - logits.max()
|
||||
p = np.exp(logp)
|
||||
sum_p = p.sum()
|
||||
entropy = -(p * logp).sum() / sum_p + np.log(sum_p)
|
||||
p /= sum_p
|
||||
#entropy = -(p * np.log(p)).sum()
|
||||
return p, entropy
|
||||
|
||||
def top_p(p, top_p):
|
||||
if top_p < 1:
|
||||
cumsum = 0.
|
||||
for i in range(len(p)):
|
||||
cumsum += p[i]
|
||||
if cumsum >= top_p:
|
||||
return i + 1
|
||||
return len(p)
|
||||
|
||||
def replicate_sampler(tokens, args, max_print=10):
|
||||
log2 = np.log(2)
|
||||
tokenizer = SentencePieceProcessor(args.tokenizer_model_file)
|
||||
piece_repr = lambda tokid: repr(tokenizer.id_to_piece(int(tokid)))
|
||||
for tokens, logits_arrs in f:
|
||||
for tokid, logits in zip(tokens, logits_arrs):
|
||||
idxs = None
|
||||
if args.top_k is not None:
|
||||
idxs = top_k_indices(logits, args.top_k)
|
||||
else:
|
||||
idxs = np.argsort(logits)[::-1]
|
||||
logits = logits[idxs]
|
||||
p, entropy = process_logits(logits, args.temperature)
|
||||
|
||||
n_top_p = top_p(p, args.top_p)
|
||||
logits = logits[:n_top_p]
|
||||
idxs = idxs[:n_top_p]
|
||||
|
||||
print(f'in:{piece_repr(tokid):10} logits: mean={logits.mean()=:5.2f} max={logits[0]:5.2f} entropy={entropy*log2:.2f} bits n={len(idxs)}')
|
||||
print(' '*13, ' '.join(f'{piece_repr(candtok)}:{prob:.2f}' for candtok, prob in zip(idxs[:max_print], p)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
with open_trace(args.trace_file) as f:
|
||||
print(f'n_vocab={f.n_vocab}')
|
||||
replicate_sampler(f, args)
|
65
examples/traceparser/parser.py
Normal file
65
examples/traceparser/parser.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import struct
|
||||
import mmap
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def open_trace(fn):
|
||||
base_header_fmt = "i" * 2
|
||||
file = open(fn, "rb")
|
||||
magic, version = struct.unpack(base_header_fmt, file.read(struct.calcsize(base_header_fmt)))
|
||||
if magic != 0x67676d74:
|
||||
raise ValueError('Invalid file magic. Must be a llama.cpp trace file')
|
||||
parser_cls = TraceParserBase._parsers.get(version)
|
||||
if parser_cls is None:
|
||||
raise ValueError(f'Unknown version {version}')
|
||||
return parser_cls(file)
|
||||
|
||||
class TraceParserBase:
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
self.mmap = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
self.pos = file.tell() # Skip magic and version header
|
||||
self.size = self.mmap.size()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.mmap.close()
|
||||
self.file.close()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.pos >= self.size:
|
||||
raise StopIteration
|
||||
return self.parse_record()
|
||||
|
||||
class TraceParserV0(TraceParserBase):
|
||||
def __init__(self, file):
|
||||
super().__init__(file)
|
||||
header_fmt = 'i' # n_vocab
|
||||
self.n_vocab, = struct.unpack_from(header_fmt, self.mmap, self.pos)
|
||||
self.pos += struct.calcsize(header_fmt)
|
||||
|
||||
def parse_record(self):
|
||||
pos = self.pos
|
||||
n_vocab = self.n_vocab
|
||||
|
||||
header_fmt = 'i' # n_tokens
|
||||
n_tokens, = struct.unpack_from(header_fmt, self.mmap, pos)
|
||||
pos += struct.calcsize(header_fmt)
|
||||
tokens = np.frombuffer(self.mmap, dtype=np.int32, count=n_tokens, offset=pos)
|
||||
pos += tokens.itemsize * tokens.size
|
||||
logits = np.frombuffer(self.mmap, dtype=np.float32, count=n_tokens * n_vocab, offset=pos)
|
||||
pos += logits.itemsize * logits.size
|
||||
|
||||
assert pos <= self.size
|
||||
self.pos = pos
|
||||
return tokens, logits.reshape((n_tokens, n_vocab))
|
||||
|
||||
TraceParserBase._parsers = {
|
||||
0: TraceParserV0
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue