removed main.exe to reduce clutter, added support for rep pen in gptj

This commit is contained in:
Concedo 2023-04-04 20:43:13 +08:00
parent 9c0dbbb08b
commit 52de932842
11 changed files with 46 additions and 22 deletions

View file

@ -139,7 +139,7 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main llamalib quantize llamalib_blas
default: llamalib quantize llamalib_blas
#
# Build library

View file

@ -26,6 +26,7 @@ static int n_past = 0;
static int n_threads = 4;
static int n_batch = 8;
static std::string modelname;
static std::vector<gpt_vocab::id> last_n_tokens;
static std::vector<gpt_vocab::id> current_context_tokens;
static size_t mem_per_token = 0;
static std::vector<float> logits;
@ -87,9 +88,15 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
params.top_k = inputs.top_k;
params.top_p = inputs.top_p;
params.temp = inputs.temperature;
params.repeat_last_n = inputs.rep_pen_range;
params.repeat_penalty = inputs.rep_pen;
params.n_batch = n_batch;
params.n_threads = n_threads;
if (params.repeat_last_n < 1)
{
params.repeat_last_n = 1;
}
if (params.top_k < 1)
{
params.top_k = 300; //to disable top_k we actually need to increase this value to a very high number
@ -113,6 +120,10 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
//determine how much npast we have to rewind from the current state
std::vector<gpt_vocab::id> embd;
int last_n_size = params.repeat_last_n;
last_n_tokens.resize(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
n_past = 0;
//fast forward the past based on identical tokens, stop once a divergence is noted
@ -122,6 +133,7 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
if (current_context_tokens[i] == embd_inp[i])
{
n_past += 1;
last_n_tokens.push_back(current_context_tokens[i]);
}
else
{
@ -133,6 +145,7 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
}
}
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
@ -203,6 +216,7 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
if (!startedsampling)
{
@ -218,10 +232,11 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
// set the logit of the eos token (2) to zero to avoid sampling it
logits[50256] = 0;
//set logits of opening square bracket to zero.
id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
id = gptj_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab),last_n_tokens,repeat_penalty, top_k, top_p, temp, rng);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
current_context_tokens.push_back(id);
}
@ -240,7 +255,9 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
// some user input remains from prompt or interaction, forward it to processing
while ((int)embd_inp.size() > input_consumed)
{
embd.push_back(embd_inp[input_consumed]);
embd.push_back(embd_inp[input_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[input_consumed]);
current_context_tokens.push_back(embd_inp[input_consumed]);
++input_consumed;
if ((int)embd.size() >= params.n_batch)
@ -252,7 +269,7 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
}
time2 = timer_check();
printf("\nTime Taken - Processing:%.1fs, Generation:%.1fs, Total:%.1fs", time1, time2, (time1 + time2));
fflush(stdout);
output.status = 1;
snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
return output;

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

View file

@ -263,7 +263,7 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_out
}
time2 = timer_check();
printf("\nTime Taken - Processing:%.1fs, Generation:%.1fs, Total:%.1fs", time1, time2, (time1 + time2));
fflush(stdout);
output.status = 1;
snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
return output;

BIN
main.exe

Binary file not shown.

View file

@ -11,16 +11,19 @@
#include "model_adapter.h"
static clock_t bench_timer = 0;
#include <chrono>
static auto bench_timer = std::chrono::high_resolution_clock().now();
void timer_start()
{
bench_timer = clock();
bench_timer = std::chrono::high_resolution_clock().now();
}
double timer_check()
{
double ticks = clock() - bench_timer;
double time_taken = ((double)ticks) / CLOCKS_PER_SEC;
auto endtime = std::chrono::high_resolution_clock().now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endtime - bench_timer);
double time_taken = duration.count()/1000.0;
return time_taken;
}

View file

@ -294,7 +294,7 @@ ModelLoadResult legacy_gptj_model_load(const std::string & fname, gptj_model_v1
//test for transposition and retry older loader
if(tensor->ne[0]==ne[1] && tensor->ne[1]==ne[0] && should_transpose_layer(name))
{
printf("\nFound a transposed tensor. This could be an older model. Retrying load...");
printf("\nFound a transposed tensor. This could be an older or newer model. Retrying load...");
ggml_v1_free(ctx);
return ModelLoadResult::RETRY_LOAD;
}

View file

@ -289,7 +289,7 @@ ModelLoadResult gptj_model_load(const std::string & fname, gptj_model & model, g
//test for transposition and retry older loader
if(tensor->ne[0]==ne[1] && tensor->ne[1]==ne[0] && should_transpose_layer(name))
{
printf("\nFound a transposed tensor. This could be an older model. Retrying load...");
printf("\nFound a transposed tensor. This could be an older or newer model. Retrying load...");
ggml_free(ctx);
return ModelLoadResult::RETRY_LOAD;
}

Binary file not shown.