Compute perplexity over prompt (#270)

* Compute perplexity over prompt

* More accurate perplexity calculation - over all logits in the context window (so 512x more tokens!)

* Output all perplexitiies

* Add timing/ETA
This commit is contained in:
Gary Linscott 2023-03-21 09:27:42 -07:00 committed by GitHub
parent 3ab3e6582f
commit 486ae645fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 98 additions and 13 deletions

View file

@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt.push_back(argv[++i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {
params.ignore_eos = true;
} else if (arg == "--n_parts") {
@ -120,6 +122,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");
@ -596,7 +599,7 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
char * pdst = (char *) dst;
for (int j = 0; j < n; j += k) {
for (int j = 0; j < n; j += k) {
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
@ -619,7 +622,7 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
*(float *) pd = d;
*(float *) pm = min;
pd += bs;
pd += bs;
pm += bs;
for (int l = 0; l < qk; l += 2) {