commit 66b24d9f31
parent 629f917cd6
Author: weirui.kwr
Date:   2023-11-03 15:44:53 +08:00

2 changed files with 24808 additions and 170 deletions


@@ -5,6 +5,13 @@
 #include <sstream>
 #include <functional>
+#include <nlohmann/json.hpp>
+#include <fstream>
+#include <iostream>
+
+using json = nlohmann::json;
 
 struct random_normal_distribution {
     std::mt19937 gen;
     std::normal_distribution<float> rd;
@@ -706,6 +713,65 @@ void save_train_state_gguf(struct gguf_context * fctx, struct train_state * trai
     save_opt_context_gguf(fctx, train->opt);
 }
+
+struct masked_locations {
+    // index into the data; range is [0, batch_size * seq_length - 1]
+    int ir;
+    // position where the response starts; loss is computed starting from this position; range is [0, context_length - 1]
+    int ires;
+    // from this position onward everything is padding tokens; loss computation stops before this position (this position itself gets no loss); range is [0, context_length - 1]
+    int ipad;
+};
+
+struct Data {
+    std::string prompt;
+    std::string response;
+    // std::vector<int> mask;
+};
+
+std::vector<Data> read_json(std::string filename) {
+    std::vector<Data> data;
+    std::ifstream file(filename);
+    if (file.is_open()) {
+        json j;
+        file >> j;
+        for (auto& item : j) {
+            Data d;
+            std::string prompt = item["prompt"];
+            std::string response = item["response"];
+            d.prompt = prompt;
+            d.response = response;
+            data.push_back(d);
+        }
+        file.close();
+    } else {
+        std::cout << "Unable to open file " << filename << std::endl;
+    }
+    return data;
+}
+
+int tokenize_buffer(llama_context* lctx, const std::vector<char>& buffer, std::vector<int>& out_tokens) {
+    int n_tokens = llama_tokenize(
+        llama_get_model(lctx),
+        buffer.data(),
+        (int) buffer.size(),
+        out_tokens.data(),
+        (int) out_tokens.size(),
+        false, false);
+    if (n_tokens < 0) {
+        out_tokens.resize(-n_tokens);
+        n_tokens = llama_tokenize(
+            llama_get_model(lctx),
+            buffer.data(),
+            (int) buffer.size(),
+            out_tokens.data(),
+            (int) out_tokens.size(),
+            false, false);
+    }
+    if (n_tokens >= 0) {
+        out_tokens.resize(n_tokens);
+    }
+    return n_tokens;
+}
+
 struct llama_file {
     // use FILE * so we don't have to re-open the file to mmap
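
The read_json helper added above only looks at the "prompt" and "response" fields of each array element, which implies the training data file is a JSON array of objects carrying those two strings. Below is a minimal, hedged sketch of that assumed format, parsed the same way read_json does; the sample contents and the standalone main() are illustrative only and not part of the commit.

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
    // Assumed file contents: a JSON array of {"prompt", "response"} objects.
    const char * text = R"([
        {"prompt": "Question: what is 2 + 2?\nAnswer: ", "response": "4"},
        {"prompt": "Translate to French: hello\n", "response": "bonjour"}
    ])";

    nlohmann::json j = nlohmann::json::parse(text);
    for (auto & item : j) {
        // Same field-access pattern as read_json above.
        std::string prompt   = item["prompt"];
        std::string response = item["response"];
        std::cout << prompt << response << "\n";
    }
    return 0;
}
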
@@ -827,192 +893,75 @@ size_t tokenize_file(
         std::vector<llama_token> & out_tokens,
         std::vector<size_t> & out_samples_begin,
         std::vector<size_t> & out_samples_size) {
-    struct llama_file f(filename, "rb");
-
-    if (f.size == 0) {
-        out_tokens.clear();
-        out_samples_begin.clear();
-        out_samples_size.clear();
-        printf("%s: warning: empty or not existing training data file '%s'\n",
-            __func__, filename);
-        return out_tokens.size();
-    }
+    std::vector<Data> samples = read_json(filename);
 
     // account for possible leading whitespace that will be added by tokenizer
     // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
     const int n_max_tokens_overhead = 1;
 
-    std::vector<char> buf;
-    buf.resize(f.size);
-    f.read_raw(buf.data(), f.size);
-
-    std::vector<int> utf8_units;
-    std::vector<int> utf8_nunits;
-    utf8_units.resize(buf.size());
-    utf8_nunits.resize(buf.size());
-    mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
-
-    if (sample_start.size() == 0) {
-        // tokenize all data at once
-        out_tokens.resize(buf.size() + n_max_tokens_overhead);
-
-        int n_tokens = llama_tokenize(
-            llama_get_model(lctx),
-            buf.data(),
-            (int) buf.size(),
-            out_tokens.data(),
-            (int) out_tokens.size(),
-            false, false);
-        if (n_tokens < 0) {
-            out_tokens.resize(-n_tokens);
-            n_tokens = llama_tokenize(
-                llama_get_model(lctx),
-                buf.data(),
-                (int) buf.size(),
-                out_tokens.data(),
-                (int) out_tokens.size(),
-                false, false);
-        }
-        if (n_tokens >= 0) {
-            out_tokens.resize(n_tokens);
-        }
-
-        // generate sample starts at all token positions
-        out_samples_begin.clear();
-        out_samples_begin.push_back(0);
-        out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
-        size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
-        for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
-            out_samples_begin.push_back(sample_begin);
-            out_samples_size.push_back(context_length);
-        }
-    } else {
-        // split data into samples and tokenize each sample
-        std::string data_str(buf.data(), buf.size());
-        out_samples_begin.clear();
-        out_samples_size.clear();
-        out_tokens.clear();
-
-        // find all positions of pattern sample_start
-        size_t sample_begin = data_str.find(sample_start, 0);
-        while (sample_begin != std::string::npos) {
-            out_samples_begin.push_back(sample_begin);
-            const size_t search_start = sample_begin + sample_start.size();
-            sample_begin = data_str.find(sample_start, search_start);
-        }
-        if (out_samples_begin.size() == 0) {
-            printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
-                __func__, sample_start.c_str());
-            out_samples_begin.push_back(0);
-        }
-
-        out_samples_size.resize(out_samples_begin.size(), 0);
-
-        std::vector<char> buf_sample;
-        std::vector<llama_token> tok_sample;
-
-        const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
-        size_t found_too_big_sample   = 0;
-        size_t found_too_small_sample = 0;
-        size_t found_empty_sample     = 0;
-        size_t found_min_sample_size  = SIZE_MAX;
-        size_t found_max_sample_size  = 0;
-
-        size_t max_token_text_size = 0;
-        int n_vocab = llama_n_vocab(llama_get_model(lctx));
-        for (llama_token token=0; token < n_vocab; ++token) {
-            max_token_text_size = std::max(
-                max_token_text_size,
-                strlen(llama_token_get_text(llama_get_model(lctx), token)));
-        }
-
-        // upper bound of context byte length.
-        // strings with this byte length should always tokenize to at least context_length tokens.
-        size_t context_byte_len = max_token_text_size*context_length;
-
-        for (unsigned i=0; i<out_samples_begin.size(); ++i) {
-            // determine sample begin and end from pattern positions
-            size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
-            size_t sample_end   = overlapping_samples
-                ? std::min(
-                    data_str.size(),
-                    sample_begin + context_byte_len)
-                : (i+1 < out_samples_begin.size()
-                    ? out_samples_begin[i+1]
-                    : data_str.size());
-            if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
-                // sample end is in the middle of an utf8 character.
-                // advance sample_end to the begin of the next utf8 character.
-                sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
-            }
-            size_t sample_size = sample_end - sample_begin;
-            if (sample_size == 0) {
-                ++found_empty_sample;
-            }
-
-            if (sample_size > 0) {
-                // llama_tokenize expects zero terminated string,
-                // copy sample into buffer and zero terminate it.
-                buf_sample.resize(sample_size);
-                memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);
-
-                // printf("sample: '%s'\n", buf_sample.data());
-
-                // tokenize the sample
-                tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
-                int n_tokens = llama_tokenize(llama_get_model(lctx),
-                    buf_sample.data(),
-                    (int) buf_sample.size(),
-                    tok_sample.data(),
-                    (int) tok_sample.size(),
-                    false, false);
-                if (n_tokens < 0) {
-                    tok_sample.resize(-n_tokens);
-                    n_tokens = llama_tokenize(llama_get_model(lctx),
-                        buf_sample.data(),
-                        (int) buf_sample.size(),
-                        tok_sample.data(),
-                        (int) tok_sample.size(),
-                        false, false);
-                    GGML_ASSERT(n_tokens >= 0);
-                }
-                GGML_ASSERT(n_tokens <= (int) tok_sample.size());
-
-                if ((size_t) n_tokens > context_length) {
-                    ++found_too_big_sample;
-                } else if ((size_t) n_tokens < context_length) {
-                    ++found_too_small_sample;
-                }
-                found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
-                found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);
-
-                // write out tokens, start and size of sample
-                // overwrite the string start position with the token start position
-                out_samples_begin[i] = out_tokens.size();
-                out_samples_size[i] = (size_t) n_tokens;
-                out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
-            } else {
-                out_samples_begin[i] = out_tokens.size();
-                out_samples_size[i] = 0;
-            }
-        }
-
-        if (found_too_big_sample > 0) {
-            printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
-                __func__, found_too_big_sample, found_max_sample_size, context_length);
-        }
-        if (found_too_small_sample > 0) {
-            printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
-                __func__, found_too_small_sample, found_min_sample_size, context_length);
-        }
-        if (found_empty_sample) {
-            printf("%s: warning: found %zu empty samples.\n",
-                __func__, found_empty_sample);
-        }
-    }
+    std::vector<masked_locations> masks;
+
+    out_samples_begin.clear();
+    out_samples_begin.push_back(0);
+
+    for (size_t sample_index = 0; sample_index < samples.size(); ++sample_index) {
+        std::vector<char> buf_prompt;
+        std::vector<int> utf8_units_prompt;
+        std::vector<int> utf8_nunits_prompt;
+        std::vector<llama_token> out_tokens_prompt;
+
+        std::vector<char> buf_response;
+        std::vector<int> utf8_units_response;
+        std::vector<int> utf8_nunits_response;
+        std::vector<llama_token> out_tokens_response;
+
+        // Prepare prompt buffer
+        buf_prompt.resize(samples[sample_index].prompt.size());
+        std::copy(samples[sample_index].prompt.begin(), samples[sample_index].prompt.end(), buf_prompt.begin());
+        utf8_units_prompt.resize(buf_prompt.size());
+        utf8_nunits_prompt.resize(buf_prompt.size());
+        mark_utf8_units(buf_prompt.data(), utf8_units_prompt.data(), utf8_nunits_prompt.data(), buf_prompt.size());
+
+        // Prepare response buffer
+        buf_response.clear();
+        utf8_units_response.clear();
+        utf8_nunits_response.clear();
+
+        buf_response.resize(samples[sample_index].response.size());
+        std::copy(samples[sample_index].response.begin(), samples[sample_index].response.end(), buf_response.begin());
+        utf8_units_response.resize(buf_response.size());
+        utf8_nunits_response.resize(buf_response.size());
+        mark_utf8_units(buf_response.data(), utf8_units_response.data(), utf8_nunits_response.data(), buf_response.size());
+
+        out_tokens_prompt.resize(buf_prompt.size() + n_max_tokens_overhead);
+        out_tokens_response.resize(buf_response.size() + n_max_tokens_overhead);
+        int n_tokens_prompt = tokenize_buffer(lctx, buf_prompt, out_tokens_prompt);
+        int n_tokens_response = tokenize_buffer(lctx, buf_response, out_tokens_response);
+        int total_tokens = out_tokens_prompt.size() + out_tokens_response.size();
+
+        // generate sample starts at all token positions
+        out_samples_size.push_back(context_length);
+        out_samples_begin.push_back(sample_index * context_length);
+
+        // Concat and padding
+        out_tokens.insert(out_tokens.end(), out_tokens_prompt.begin(), out_tokens_prompt.end());
+        out_tokens.insert(out_tokens.end(), out_tokens_response.begin(), out_tokens_response.end());
+        out_tokens.resize((sample_index + 1) * context_length);
+
+        // Calculate mask pos
+        masked_locations ml;
+        ml.ir = sample_index * context_length;
+        ml.ires = sample_index * context_length + out_tokens_prompt.size();
+        ml.ipad = total_tokens > context_length ? (sample_index + 1) * context_length : sample_index * context_length + total_tokens;
+        masks.push_back(ml);
+    }
+
+    // Pop the last one
+    out_samples_begin.pop_back();
 
     printf("%s: total number of samples: %zu\n",
         __func__, out_samples_begin.size());
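
The new tokenize_file builds a masks vector, but this diff does not show where those masked_locations entries are consumed. As a rough, hedged sketch of the intent documented in the struct's comments (loss starts at ires and stops before ipad), one entry could be turned into per-position loss weights as below; build_loss_mask and the weight layout are hypothetical and not part of the commit.

#include <vector>

// Hypothetical helper, not from the commit: weight 1.0f for response tokens,
// 0.0f for prompt and padding positions, following the ir/ires/ipad comments
// on the masked_locations struct added above (ires/ipad are absolute indices
// into out_tokens, ipad exclusive).
std::vector<float> build_loss_mask(const masked_locations & ml, int context_length) {
    std::vector<float> weights(context_length, 0.0f);
    for (int pos = 0; pos < context_length; ++pos) {
        const int idx = ml.ir + pos;           // absolute index into out_tokens
        if (idx >= ml.ires && idx < ml.ipad) { // response region only
            weights[pos] = 1.0f;
        }
    }
    return weights;
}
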

nlohmann/json.hpp (new file, 24689 lines)

File diff suppressed because it is too large.