Add BPE dropout support, use it in training.
This commit is contained in:
parent
46088f7231
commit
685d236d8b
6 changed files with 29 additions and 10 deletions
|
@ -527,7 +527,7 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
|
|||
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
||||
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
|
||||
std::vector<llama_token> res(text.size() + (int) add_bos);
|
||||
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
|
||||
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos, 0.0);
|
||||
assert(n >= 0);
|
||||
res.resize(n);
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
auto tokens = std::vector<llama_token>(params.n_ctx);
|
||||
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
|
||||
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true, 0.0);
|
||||
|
||||
if (n_prompt_tokens < 1) {
|
||||
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
|
||||
|
|
|
@ -2187,7 +2187,7 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
|
|||
|
||||
out.resize(buf.size());
|
||||
|
||||
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false);
|
||||
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false, 0.1f);
|
||||
if (n_tokens >= 0) {
|
||||
out.resize(n_tokens);
|
||||
}
|
||||
|
|
28
llama.cpp
28
llama.cpp
|
@ -48,6 +48,7 @@
|
|||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
|
@ -1717,7 +1718,7 @@ struct llama_sp_bigram {
|
|||
// original implementation:
|
||||
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
||||
struct llama_tokenizer {
|
||||
llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
|
||||
llama_tokenizer(const llama_vocab & vocab, float dropout): vocab_(vocab), dropout_(dropout) {}
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||
// split string into utf8 chars
|
||||
|
@ -1759,6 +1760,9 @@ struct llama_tokenizer {
|
|||
right_sym.n = 0;
|
||||
|
||||
//printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
|
||||
if (skip_merge()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// remove the right sym from the chain
|
||||
left_sym.next = right_sym.next;
|
||||
|
@ -1814,13 +1818,26 @@ private:
|
|||
work_queue_.push(bigram);
|
||||
}
|
||||
|
||||
bool skip_merge()
|
||||
{
|
||||
std::uniform_real_distribution<> gen(0.0, 1.0);
|
||||
if (dropout_ <= 0.0) {
|
||||
return false;
|
||||
}
|
||||
if (dropout_ >= 1.0)
|
||||
return true;
|
||||
return gen(rng) < dropout_;
|
||||
}
|
||||
|
||||
const llama_vocab & vocab_;
|
||||
std::vector<llama_sp_symbol> symbols_;
|
||||
llama_sp_bigram::queue work_queue_;
|
||||
float dropout_;
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
||||
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
||||
llama_tokenizer tokenizer(vocab);
|
||||
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos, float dropout) {
|
||||
llama_tokenizer tokenizer(vocab, dropout);
|
||||
std::vector<llama_vocab::id> output;
|
||||
|
||||
if (text.empty()) {
|
||||
|
@ -3407,8 +3424,9 @@ int llama_tokenize(
|
|||
const char * text,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos) {
|
||||
auto res = llama_tokenize(ctx->vocab, text, add_bos);
|
||||
bool add_bos,
|
||||
float dropout) {
|
||||
auto res = llama_tokenize(ctx->vocab, text, add_bos, dropout);
|
||||
|
||||
if (n_max_tokens < (int) res.size()) {
|
||||
fprintf(stderr, "%s: too many tokens\n", __func__);
|
||||
|
|
3
llama.h
3
llama.h
|
@ -252,7 +252,8 @@ extern "C" {
|
|||
const char * text,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
bool add_bos,
|
||||
float dropout);
|
||||
|
||||
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||
|
|
|
@ -64,7 +64,7 @@ int main(int argc, char **argv) {
|
|||
|
||||
for (const auto & test_kv : k_tests()) {
|
||||
std::vector<llama_token> res(test_kv.first.size());
|
||||
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true);
|
||||
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true, 0.0);
|
||||
res.resize(n);
|
||||
|
||||
bool correct = res.size() == test_kv.second.size();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue