Add BPE dropout support, use it in training.
This commit is contained in:
parent
46088f7231
commit
685d236d8b
6 changed files with 29 additions and 10 deletions
|
@ -527,7 +527,7 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
|
||||||
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
||||||
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
|
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
|
||||||
std::vector<llama_token> res(text.size() + (int) add_bos);
|
std::vector<llama_token> res(text.size() + (int) add_bos);
|
||||||
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
|
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos, 0.0);
|
||||||
assert(n >= 0);
|
assert(n >= 0);
|
||||||
res.resize(n);
|
res.resize(n);
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
auto tokens = std::vector<llama_token>(params.n_ctx);
|
auto tokens = std::vector<llama_token>(params.n_ctx);
|
||||||
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
|
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true, 0.0);
|
||||||
|
|
||||||
if (n_prompt_tokens < 1) {
|
if (n_prompt_tokens < 1) {
|
||||||
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
|
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
|
||||||
|
|
|
@ -2187,7 +2187,7 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
|
||||||
|
|
||||||
out.resize(buf.size());
|
out.resize(buf.size());
|
||||||
|
|
||||||
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false);
|
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false, 0.1f);
|
||||||
if (n_tokens >= 0) {
|
if (n_tokens >= 0) {
|
||||||
out.resize(n_tokens);
|
out.resize(n_tokens);
|
||||||
}
|
}
|
||||||
|
|
28
llama.cpp
28
llama.cpp
|
@ -48,6 +48,7 @@
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <random>
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
|
@ -1717,7 +1718,7 @@ struct llama_sp_bigram {
|
||||||
// original implementation:
|
// original implementation:
|
||||||
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
||||||
struct llama_tokenizer {
|
struct llama_tokenizer {
|
||||||
llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
|
llama_tokenizer(const llama_vocab & vocab, float dropout): vocab_(vocab), dropout_(dropout) {}
|
||||||
|
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
// split string into utf8 chars
|
// split string into utf8 chars
|
||||||
|
@ -1759,6 +1760,9 @@ struct llama_tokenizer {
|
||||||
right_sym.n = 0;
|
right_sym.n = 0;
|
||||||
|
|
||||||
//printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
|
//printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
|
||||||
|
if (skip_merge()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// remove the right sym from the chain
|
// remove the right sym from the chain
|
||||||
left_sym.next = right_sym.next;
|
left_sym.next = right_sym.next;
|
||||||
|
@ -1814,13 +1818,26 @@ private:
|
||||||
work_queue_.push(bigram);
|
work_queue_.push(bigram);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool skip_merge()
|
||||||
|
{
|
||||||
|
std::uniform_real_distribution<> gen(0.0, 1.0);
|
||||||
|
if (dropout_ <= 0.0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (dropout_ >= 1.0)
|
||||||
|
return true;
|
||||||
|
return gen(rng) < dropout_;
|
||||||
|
}
|
||||||
|
|
||||||
const llama_vocab & vocab_;
|
const llama_vocab & vocab_;
|
||||||
std::vector<llama_sp_symbol> symbols_;
|
std::vector<llama_sp_symbol> symbols_;
|
||||||
llama_sp_bigram::queue work_queue_;
|
llama_sp_bigram::queue work_queue_;
|
||||||
|
float dropout_;
|
||||||
|
std::mt19937 rng;
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos, float dropout) {
|
||||||
llama_tokenizer tokenizer(vocab);
|
llama_tokenizer tokenizer(vocab, dropout);
|
||||||
std::vector<llama_vocab::id> output;
|
std::vector<llama_vocab::id> output;
|
||||||
|
|
||||||
if (text.empty()) {
|
if (text.empty()) {
|
||||||
|
@ -3407,8 +3424,9 @@ int llama_tokenize(
|
||||||
const char * text,
|
const char * text,
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int n_max_tokens,
|
int n_max_tokens,
|
||||||
bool add_bos) {
|
bool add_bos,
|
||||||
auto res = llama_tokenize(ctx->vocab, text, add_bos);
|
float dropout) {
|
||||||
|
auto res = llama_tokenize(ctx->vocab, text, add_bos, dropout);
|
||||||
|
|
||||||
if (n_max_tokens < (int) res.size()) {
|
if (n_max_tokens < (int) res.size()) {
|
||||||
fprintf(stderr, "%s: too many tokens\n", __func__);
|
fprintf(stderr, "%s: too many tokens\n", __func__);
|
||||||
|
|
3
llama.h
3
llama.h
|
@ -252,7 +252,8 @@ extern "C" {
|
||||||
const char * text,
|
const char * text,
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int n_max_tokens,
|
int n_max_tokens,
|
||||||
bool add_bos);
|
bool add_bos,
|
||||||
|
float dropout);
|
||||||
|
|
||||||
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
|
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
|
||||||
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||||
|
|
|
@ -64,7 +64,7 @@ int main(int argc, char **argv) {
|
||||||
|
|
||||||
for (const auto & test_kv : k_tests()) {
|
for (const auto & test_kv : k_tests()) {
|
||||||
std::vector<llama_token> res(test_kv.first.size());
|
std::vector<llama_token> res(test_kv.first.size());
|
||||||
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true);
|
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true, 0.0);
|
||||||
res.resize(n);
|
res.resize(n);
|
||||||
|
|
||||||
bool correct = res.size() == test_kv.second.size();
|
bool correct = res.size() == test_kv.second.size();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue