Add BPE dropout support, use it in training.

This commit is contained in:
Howard Su 2023-07-02 22:57:14 +08:00
parent 46088f7231
commit 685d236d8b
6 changed files with 29 additions and 10 deletions

View file

@ -527,7 +527,7 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int) add_bos); std::vector<llama_token> res(text.size() + (int) add_bos);
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos, 0.0);
assert(n >= 0); assert(n >= 0);
res.resize(n); res.resize(n);

View file

@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
auto tokens = std::vector<llama_token>(params.n_ctx); auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true); auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true, 0.0);
if (n_prompt_tokens < 1) { if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);

View file

@ -2187,7 +2187,7 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
out.resize(buf.size()); out.resize(buf.size());
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false); int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false, 0.1f);
if (n_tokens >= 0) { if (n_tokens >= 0) {
out.resize(n_tokens); out.resize(n_tokens);
} }

View file

@ -48,6 +48,7 @@
#include <mutex> #include <mutex>
#include <sstream> #include <sstream>
#include <numeric> #include <numeric>
#include <random>
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
@ -1717,7 +1718,7 @@ struct llama_sp_bigram {
// original implementation: // original implementation:
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
struct llama_tokenizer { struct llama_tokenizer {
llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {} llama_tokenizer(const llama_vocab & vocab, float dropout): vocab_(vocab), dropout_(dropout) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
// split string into utf8 chars // split string into utf8 chars
@ -1759,6 +1760,9 @@ struct llama_tokenizer {
right_sym.n = 0; right_sym.n = 0;
//printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); //printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
if (skip_merge()) {
continue;
}
// remove the right sym from the chain // remove the right sym from the chain
left_sym.next = right_sym.next; left_sym.next = right_sym.next;
@ -1814,13 +1818,26 @@ private:
work_queue_.push(bigram); work_queue_.push(bigram);
} }
bool skip_merge()
{
std::uniform_real_distribution<> gen(0.0, 1.0);
if (dropout_ <= 0.0) {
return false;
}
if (dropout_ >= 1.0)
return true;
return gen(rng) < dropout_;
}
const llama_vocab & vocab_; const llama_vocab & vocab_;
std::vector<llama_sp_symbol> symbols_; std::vector<llama_sp_symbol> symbols_;
llama_sp_bigram::queue work_queue_; llama_sp_bigram::queue work_queue_;
float dropout_;
std::mt19937 rng;
}; };
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos, float dropout) {
llama_tokenizer tokenizer(vocab); llama_tokenizer tokenizer(vocab, dropout);
std::vector<llama_vocab::id> output; std::vector<llama_vocab::id> output;
if (text.empty()) { if (text.empty()) {
@ -3407,8 +3424,9 @@ int llama_tokenize(
const char * text, const char * text,
llama_token * tokens, llama_token * tokens,
int n_max_tokens, int n_max_tokens,
bool add_bos) { bool add_bos,
auto res = llama_tokenize(ctx->vocab, text, add_bos); float dropout) {
auto res = llama_tokenize(ctx->vocab, text, add_bos, dropout);
if (n_max_tokens < (int) res.size()) { if (n_max_tokens < (int) res.size()) {
fprintf(stderr, "%s: too many tokens\n", __func__); fprintf(stderr, "%s: too many tokens\n", __func__);

View file

@ -252,7 +252,8 @@ extern "C" {
const char * text, const char * text,
llama_token * tokens, llama_token * tokens,
int n_max_tokens, int n_max_tokens,
bool add_bos); bool add_bos,
float dropout);
LLAMA_API int llama_n_vocab(const struct llama_context * ctx); LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_ctx (const struct llama_context * ctx);

View file

@ -64,7 +64,7 @@ int main(int argc, char **argv) {
for (const auto & test_kv : k_tests()) { for (const auto & test_kv : k_tests()) {
std::vector<llama_token> res(test_kv.first.size()); std::vector<llama_token> res(test_kv.first.size());
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true); const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true, 0.0);
res.resize(n); res.resize(n);
bool correct = res.size() == test_kv.second.size(); bool correct = res.size() == test_kv.second.size();