mirror of
https://github.com/jart/cosmopolitan.git
synced 2025-02-07 15:03:34 +00:00
8fdb31681a
llama.com can now load weights that use the new file format which was introduced a few weeks ago. Note that, unlike llama.cpp, we will keep support for old file formats in our tool so you don't need to convert your weights when the upstream project makes breaking changes. Please note that using ggjt v3 does make avx2 inference go 5% faster for me.
389 lines
16 KiB
C++
389 lines
16 KiB
C++
/*-*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*-│
|
|
│vi: set net ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi│
|
|
╚──────────────────────────────────────────────────────────────────────────────╝
|
|
│ │
|
|
│ radpajama.com │
|
|
│ Copyright (c) 2023 Ariel Núñez │
|
|
│ Copyright (c) 2023 Georgi Gerganov │
|
|
│ │
|
|
│ Permission is hereby granted, free of charge, to any person obtaining │
|
|
│ a copy of this software and associated documentation files (the │
|
|
│ "Software"), to deal in the Software without restriction, including │
|
|
│ without limitation the rights to use, copy, modify, merge, publish, │
|
|
│ distribute, sublicense, and/or sell copies of the Software, and to │
|
|
│ permit persons to whom the Software is furnished to do so, subject to │
|
|
│ the following conditions: │
|
|
│ │
|
|
│ The above copyright notice and this permission notice shall be │
|
|
│ included in all copies or substantial portions of the Software. │
|
|
│ │
|
|
│ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, │
|
|
│ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF │
|
|
│ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. │
|
|
│ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY │
|
|
│ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, │
|
|
│ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE │
|
|
│ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. │
|
|
│ │
|
|
╚─────────────────────────────────────────────────────────────────────────────*/
|
|
#include "libc/calls/calls.h"
|
|
#include "libc/calls/sigtimedwait.h"
|
|
#include "libc/calls/struct/sigaction.h"
|
|
#include "libc/calls/struct/siginfo.h"
|
|
#include "libc/calls/weirdtypes.h"
|
|
#include "libc/log/log.h"
|
|
#include "libc/runtime/pathconf.h"
|
|
#include "libc/runtime/runtime.h"
|
|
#include "libc/runtime/sysconf.h"
|
|
#include "libc/sysv/consts/f.h"
|
|
#include "libc/sysv/consts/fileno.h"
|
|
#include "libc/sysv/consts/o.h"
|
|
#include "libc/sysv/consts/ok.h"
|
|
#include "libc/sysv/consts/sa.h"
|
|
#include "libc/sysv/consts/sicode.h"
|
|
#include "libc/sysv/consts/ss.h"
|
|
#include "libc/time/time.h"
|
|
#include "third_party/getopt/getopt.h"
|
|
#include "third_party/libcxx/algorithm"
|
|
#include "third_party/libcxx/cassert"
|
|
#include "third_party/libcxx/cinttypes"
|
|
#include "third_party/libcxx/cmath"
|
|
#include "third_party/libcxx/cstdio"
|
|
#include "third_party/libcxx/cstring"
|
|
#include "third_party/libcxx/ctime"
|
|
#include "third_party/libcxx/fstream"
|
|
#include "third_party/libcxx/iostream"
|
|
#include "third_party/libcxx/string"
|
|
#include "third_party/libcxx/vector"
|
|
#include "third_party/musl/crypt.h"
|
|
#include "third_party/musl/lockf.h"
|
|
#include "third_party/radpajama/common-gptneox.h"
|
|
#include "third_party/radpajama/gptneox.h"
|
|
// clang-format off
|
|
|
|
// Console color state; shared between main() and sigint_handler() so the
// handler can reset the terminal color.
static console_state con_st;
|
|
// Points at main()'s context pointer so sigint_handler() can print timings
// for the loaded model before exiting.
static gptneox_context ** g_ctx;
|
|
|
|
// Set by main() while it waits for user input and cleared while generating.
// sigint_handler() sets it on a Ctrl-C during generation, which makes the
// generation loop stop; a Ctrl-C while it is already set exits the process.
// NOTE(review): written from a signal handler but declared as a plain bool;
// consider volatile sig_atomic_t — confirm.
static bool is_interacting = false;
|
|
|
|
void sigint_handler(int signo) {
|
|
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
|
printf("\n"); // this also force flush stdout.
|
|
if (signo == SIGINT) {
|
|
if (!is_interacting) {
|
|
is_interacting=true;
|
|
} else {
|
|
gptneox_print_timings(*g_ctx);
|
|
_exit(130);
|
|
}
|
|
}
|
|
}
|
|
|
|
int main(int argc, char ** argv) {
|
|
gpt_params params;
|
|
params.model = "./models/ggml-RedPajama-INCITE-Chat-3B-v1-f16.bin";
|
|
|
|
con_st.use_color = true;
|
|
params.n_ctx = 2048;
|
|
params.seed = 1684054676;
|
|
params.use_mmap = true;
|
|
params.use_mlock = true;
|
|
params.memory_f16 = true;
|
|
params.mem_test = false;
|
|
params.interactive = true;
|
|
params.top_k = 30;
|
|
params.top_p = 0.95;
|
|
params.temp = 0.8;
|
|
params.repeat_last_n = 3;
|
|
params.repeat_penalty = 1.1;
|
|
params.instruct = true;
|
|
params.interactive = true;
|
|
|
|
MakeProcessNice();
|
|
ShowCrashReports();
|
|
|
|
if (gpt_params_parse(argc, argv, params) == false) { return 1; }
|
|
|
|
std::mt19937 rng(params.seed);
|
|
gptneox_context * ctx;
|
|
g_ctx = &ctx;
|
|
|
|
{
|
|
auto lparams = gptneox_context_default_params();
|
|
|
|
lparams.n_ctx = params.n_ctx;
|
|
lparams.n_parts = params.n_parts;
|
|
lparams.seed = params.seed;
|
|
lparams.f16_kv = params.memory_f16;
|
|
lparams.use_mmap = params.use_mmap;
|
|
lparams.use_mlock = params.use_mlock;
|
|
|
|
ctx = gptneox_init_from_file(params.model.c_str(), lparams);
|
|
|
|
if (ctx == NULL) {
|
|
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
if (!params.lora_adapter.empty()) {
|
|
int err = gptneox_apply_lora_from_file(ctx,
|
|
params.lora_adapter.c_str(),
|
|
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
|
|
params.n_threads);
|
|
if (err != 0) {
|
|
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
MakeProcessNice();
|
|
ShowCrashReports();
|
|
|
|
// Always interactive for RedPajama chat model
|
|
params.interactive = true;
|
|
|
|
if (params.interactive) {
|
|
struct sigaction sigint_action;
|
|
sigint_action.sa_handler = sigint_handler;
|
|
sigemptyset (&sigint_action.sa_mask);
|
|
sigint_action.sa_flags = 0;
|
|
sigaction(SIGINT, &sigint_action, NULL);
|
|
}
|
|
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
|
|
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
|
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", params.n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
|
fprintf(stderr, "\n\n");
|
|
|
|
// TODO: replace with ring-buffer
|
|
std::vector<gptneox_token> last_n_tokens = std::vector<gptneox_token>();
|
|
|
|
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
|
|
|
|
const int32_t top_k = params.top_k;
|
|
const float top_p = params.top_p;
|
|
const float temp = params.temp;
|
|
const float repeat_penalty = params.repeat_penalty;
|
|
|
|
while (true) {
|
|
is_interacting = true;
|
|
int n_past = 0;
|
|
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
|
|
|
|
if (params.instruct) {
|
|
printf("\n<human>: ");
|
|
}
|
|
|
|
std::string buffer;
|
|
if (!params.input_prefix.empty()) {
|
|
buffer += params.input_prefix;
|
|
printf("%s", buffer.c_str());
|
|
}
|
|
|
|
std::string line;
|
|
bool another_line = true;
|
|
do {
|
|
if (!std::getline(std::cin, line)) {
|
|
// input stream is bad or EOF received
|
|
return 0;
|
|
}
|
|
if (line.empty() || line.back() != '\\') {
|
|
another_line = false;
|
|
} else {
|
|
line.pop_back(); // Remove the continue character
|
|
}
|
|
buffer += line;
|
|
if (another_line) {
|
|
buffer += '\n';
|
|
}
|
|
} while (another_line);
|
|
|
|
is_interacting = false;
|
|
|
|
// done taking input, reset color
|
|
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
|
|
|
// Check for input
|
|
if (buffer.length() <= 0) {
|
|
continue; // Restart loop for input
|
|
}
|
|
|
|
auto prompt_embd = ::gptneox_tokenize(ctx, buffer, false);
|
|
auto embd_inp = std::vector<gptneox_token>();
|
|
|
|
embd_inp.push_back(gptneox_str_to_token(ctx, "<"));
|
|
embd_inp.push_back(gptneox_str_to_token(ctx, "human"));
|
|
embd_inp.push_back(gptneox_str_to_token(ctx, ">:"));
|
|
|
|
embd_inp.insert(embd_inp.end(), prompt_embd.begin(), prompt_embd.end());
|
|
|
|
embd_inp.push_back(gptneox_str_to_token(ctx, "\n"));
|
|
embd_inp.push_back(gptneox_str_to_token(ctx, "<"));
|
|
embd_inp.push_back(gptneox_str_to_token(ctx, "bot"));
|
|
embd_inp.push_back(gptneox_str_to_token(ctx, ">:"));
|
|
|
|
// How many tokens to generate - check if theres space in context for atleast one token (or batch size tokens?)
|
|
auto inp_size = embd_inp.size();
|
|
auto space = params.n_ctx - inp_size;
|
|
if(space <= 0) {
|
|
fprintf(stderr, "%s : input too long\n", __func__);
|
|
continue;
|
|
}
|
|
// Send batches to eval
|
|
while (n_past < inp_size) {
|
|
auto remaining = inp_size - n_past;
|
|
int n_eval = params.n_batch < remaining ? params.n_batch : remaining;
|
|
if (gptneox_eval(ctx, &embd_inp[n_past], n_eval, n_past, params.n_threads)) {
|
|
fprintf(stderr, "<bot>: %s : failed to eval\n", __func__);
|
|
return 1;
|
|
}
|
|
n_past += n_eval;
|
|
}
|
|
|
|
const int n_ctx = gptneox_n_ctx(ctx);
|
|
const int n_vocab = gptneox_n_vocab(ctx);
|
|
|
|
const float temp = params.temp;
|
|
const int32_t top_k = params.top_k <= 0 ? gptneox_n_vocab(ctx) : params.top_k;
|
|
const float top_p = params.top_p;
|
|
const float tfs_z = params.tfs_z;
|
|
const float typical_p = params.typical_p;
|
|
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
|
const float repeat_penalty = params.repeat_penalty;
|
|
const float alpha_presence = params.presence_penalty;
|
|
const float alpha_frequency = params.frequency_penalty;
|
|
const int mirostat = params.mirostat;
|
|
const float mirostat_tau = params.mirostat_tau;
|
|
const float mirostat_eta = params.mirostat_eta;
|
|
const bool penalize_nl = params.penalize_nl;
|
|
|
|
// Eval until space runs out
|
|
auto out_count = 0;
|
|
|
|
printf("<bot>:");
|
|
while (space > 0) {
|
|
// Get token
|
|
gptneox_token id = 0;
|
|
|
|
{
|
|
auto logits = gptneox_get_logits(ctx);
|
|
|
|
// Apply params.logit_bias map
|
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
|
logits[it->first] += it->second;
|
|
}
|
|
|
|
std::vector<gptneox_token_data> candidates;
|
|
candidates.reserve(n_vocab);
|
|
for (gptneox_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
candidates.emplace_back(gptneox_token_data{token_id, logits[token_id], 0.0f});
|
|
}
|
|
|
|
gptneox_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
|
|
|
// Apply penalties
|
|
gptneox_token nl_token = gptneox_str_to_token(ctx, "\n");
|
|
float nl_logit = logits[nl_token];
|
|
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
|
gptneox_sample_repetition_penalty(ctx, &candidates_p,
|
|
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
last_n_repeat, repeat_penalty);
|
|
gptneox_sample_frequency_and_presence_penalties(ctx, &candidates_p,
|
|
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
last_n_repeat, alpha_frequency, alpha_presence);
|
|
if (!penalize_nl) {
|
|
logits[nl_token] = nl_logit;
|
|
}
|
|
|
|
if (temp <= 0) {
|
|
// Greedy sampling
|
|
id = gptneox_sample_token_greedy(ctx, &candidates_p);
|
|
} else {
|
|
if (mirostat == 1) {
|
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
const int mirostat_m = 100;
|
|
gptneox_sample_temperature(ctx, &candidates_p, temp);
|
|
id = gptneox_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
|
} else if (mirostat == 2) {
|
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
gptneox_sample_temperature(ctx, &candidates_p, temp);
|
|
id = gptneox_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
|
} else {
|
|
// Temperature sampling
|
|
gptneox_sample_top_k(ctx, &candidates_p, top_k, 1);
|
|
gptneox_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
|
|
gptneox_sample_typical(ctx, &candidates_p, typical_p, 1);
|
|
gptneox_sample_top_p(ctx, &candidates_p, top_p, 1);
|
|
gptneox_sample_temperature(ctx, &candidates_p, temp);
|
|
id = gptneox_sample_token(ctx, &candidates_p);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Inc out count and dec space
|
|
out_count += 1;
|
|
space -= 1;
|
|
// Repeat tokens update
|
|
last_n_tokens.push_back(id);
|
|
if (last_n_tokens.size() > params.repeat_last_n) {
|
|
last_n_tokens.erase(last_n_tokens.begin());
|
|
}
|
|
// Redpajama: check if the interactive is done.
|
|
//std::cout<<" last_n_tokens.size: "<< last_n_tokens[0] <<" "<< last_n_tokens[1] <<" "<< last_n_tokens[2] << std::endl;
|
|
if (last_n_tokens.size()==3 && last_n_tokens[0]==gptneox_str_to_token(ctx, "<")
|
|
&& last_n_tokens[1]==gptneox_str_to_token(ctx, "human") && last_n_tokens[2]==gptneox_str_to_token(ctx, ">:")){
|
|
space = 0;
|
|
continue;
|
|
}
|
|
|
|
// Check for eos - end early - check eos before bos in case they are the same
|
|
if (id == gptneox_token_eos()) {
|
|
space = 0;
|
|
continue;
|
|
}
|
|
// Check for bos - skip callback if so
|
|
if (id == gptneox_token_bos()) {
|
|
continue;
|
|
}
|
|
|
|
if (last_n_tokens[2]==gptneox_str_to_token(ctx, "<")){
|
|
;
|
|
}
|
|
else if (last_n_tokens[2]==gptneox_str_to_token(ctx, "human")){
|
|
if (last_n_tokens[1]==gptneox_str_to_token(ctx, "<")){
|
|
;
|
|
}
|
|
else{
|
|
printf("%s", gptneox_token_to_str(ctx, id));
|
|
}
|
|
}
|
|
else if (last_n_tokens[1]==gptneox_str_to_token(ctx, "<")){
|
|
printf("<");
|
|
printf("%s", gptneox_token_to_str(ctx, id));
|
|
}
|
|
else{
|
|
printf("%s", gptneox_token_to_str(ctx, id));
|
|
}
|
|
fflush(stdout);
|
|
// Check if we need to run another eval
|
|
if (space > 0) {
|
|
// Send generated token back into model for next generation
|
|
if (gptneox_eval(ctx, &id, 1, n_past, params.n_threads)) {
|
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
return 1;
|
|
}
|
|
// Increment past count
|
|
n_past += 1;
|
|
}
|
|
// Check for user interrupt
|
|
if (is_interacting) { space = 0; }
|
|
}
|
|
printf("\n");
|
|
fflush(stdout);
|
|
}
|
|
|
|
gptneox_print_timings(ctx);
|
|
gptneox_free(ctx);
|
|
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
|
return 0;
|
|
}
|