sft
commit 66b24d9f31, parent 629f917cd6
2 changed files with 24808 additions and 170 deletions
common/train.cpp | 291

@@ -5,6 +5,13 @@
#include <sstream>
#include <functional>

#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>

using json = nlohmann::json;

struct random_normal_distribution {
    std::mt19937 gen;
    std::normal_distribution<float> rd;
@@ -706,6 +713,65 @@ void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train)
    save_opt_context_gguf(fctx, train->opt);
}

struct masked_locations {
    // index into the data, in range [0, batch_size * seq_length - 1]
    int ir;
    // position where the response begins; the loss is computed starting from this position, in range [0, context_length - 1]
    int ires;
    // position where the padding tokens begin; loss computation stops before this position (the position itself gets no loss), in range [0, context_length - 1]
    int ipad;
};
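// Illustrative note (not part of the original commit): with context_length = 8,
// a sample at sample_index = 1 whose prompt tokenizes to 3 tokens and whose
// response tokenizes to 2 tokens would be masked as
//   ir   = 1 * 8     = 8   -- first token of the sample in the flat token buffer
//   ires = 8 + 3     = 11  -- loss starts at the first response token
//   ipad = 8 + 3 + 2 = 13  -- loss stops here; positions 13..15 are padding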
struct Data {
    std::string prompt;
    std::string response;
    // std::vector<int> mask;
};
std::vector<Data> read_json(const std::string & filename) {
    std::vector<Data> data;
    std::ifstream file(filename);
    if (file.is_open()) {
        json j;
        file >> j;
        for (auto & item : j) {
            Data d;
            d.prompt   = item["prompt"];
            d.response = item["response"];
            data.push_back(d);
        }
        file.close();
    } else {
        std::cout << "Unable to open file " << filename << std::endl;
    }
    return data;
}
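// Usage sketch (illustrative, not part of the original commit), assuming a
// JSON file holding an array of {"prompt", "response"} objects, e.g.
//   [{"prompt": "Translate to French: cat", "response": "chat"}]
//
//   std::vector<Data> samples = read_json("sft.json");
//   for (const auto & s : samples) {
//       printf("prompt: '%s' -> response: '%s'\n", s.prompt.c_str(), s.response.c_str());
//   }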
int tokenize_buffer(llama_context * lctx, const std::vector<char> & buffer, std::vector<int> & out_tokens) {
    int n_tokens = llama_tokenize(
        llama_get_model(lctx),
        buffer.data(),
        (int) buffer.size(),
        out_tokens.data(),
        (int) out_tokens.size(),
        false, false);
    if (n_tokens < 0) {
        // the output buffer was too small; -n_tokens is the required size
        out_tokens.resize(-n_tokens);
        n_tokens = llama_tokenize(
            llama_get_model(lctx),
            buffer.data(),
            (int) buffer.size(),
            out_tokens.data(),
            (int) out_tokens.size(),
            false, false);
    }
    if (n_tokens >= 0) {
        out_tokens.resize(n_tokens);
    }
    return n_tokens;
}
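// Usage sketch (illustrative, not part of the original commit): llama_tokenize
// returns the negated required token count when the output buffer is too small,
// so the function above retries once with the exact size. A caller can
// therefore start with any rough initial capacity:
//   std::string text = "hello world";
//   std::vector<char> buffer(text.begin(), text.end());
//   std::vector<int>  tokens(buffer.size() + 1); // initial guess; resized to fit
//   int n_tokens = tokenize_buffer(lctx, buffer, tokens);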
struct llama_file {
    // use FILE * so we don't have to re-open the file to mmap
@@ -827,192 +893,75 @@ size_t tokenize_file(
        std::vector<llama_token> & out_tokens,
        std::vector<size_t>      & out_samples_begin,
        std::vector<size_t>      & out_samples_size) {
    struct llama_file f(filename, "rb");

    if (f.size == 0) {
        out_tokens.clear();
        out_samples_begin.clear();
        out_samples_size.clear();
        printf("%s: warning: empty or not existing training data file '%s'\n",
                __func__, filename);
        return out_tokens.size();
    }
    std::vector<Data> samples = read_json(filename);

    // account for possible leading whitespace that will be added by tokenizer
    // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
    const int n_max_tokens_overhead = 1;

    std::vector<char> buf;
    buf.resize(f.size);
    std::vector<masked_locations> masks;

    f.read_raw(buf.data(), f.size);
    out_samples_begin.clear();
    out_samples_begin.push_back(0);

    std::vector<int> utf8_units;
    std::vector<int> utf8_nunits;
    utf8_units.resize(buf.size());
    utf8_nunits.resize(buf.size());
    mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
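// Note (illustrative, not part of the original commit): mark_utf8_units is
// assumed to record, for each byte, its index within its UTF-8 character
// (utf8_units) and that character's total byte count (utf8_nunits). For the
// 3-byte character U+4E2D the bytes 0xE4 0xB8 0xAD would get units {0, 1, 2}
// and nunits {3, 3, 3}; this is what later lets sample boundaries be advanced
// to the start of the next full character.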
    for (size_t sample_index = 0; sample_index < samples.size(); ++sample_index) {
        std::vector<char> buf_prompt;
        std::vector<int> utf8_units_prompt;
        std::vector<int> utf8_nunits_prompt;
        std::vector<llama_token> out_tokens_prompt;

        if (sample_start.size() == 0) {
            // tokenize all data at once
            out_tokens.resize(buf.size() + n_max_tokens_overhead);
            std::vector<char> buf_response;
            std::vector<int> utf8_units_response;
            std::vector<int> utf8_nunits_response;
            std::vector<llama_token> out_tokens_response;

            int n_tokens = llama_tokenize(
                llama_get_model(lctx),
                buf.data(),
                (int) buf.size(),
                out_tokens.data(),
                (int) out_tokens.size(),
                false, false);
            if (n_tokens < 0) {
                out_tokens.resize(-n_tokens);
                n_tokens = llama_tokenize(
                    llama_get_model(lctx),
                    buf.data(),
                    (int) buf.size(),
                    out_tokens.data(),
                    (int) out_tokens.size(),
                    false, false);
            }
            if (n_tokens >= 0) {
                out_tokens.resize(n_tokens);
            }
            // Prepare prompt buffer
            buf_prompt.resize(samples[sample_index].prompt.size());
            std::copy(samples[sample_index].prompt.begin(), samples[sample_index].prompt.end(), buf_prompt.begin());
            utf8_units_prompt.resize(buf_prompt.size());
            utf8_nunits_prompt.resize(buf_prompt.size());
            mark_utf8_units(buf_prompt.data(), utf8_units_prompt.data(), utf8_nunits_prompt.data(), buf_prompt.size());

            // Prepare response buffer
            buf_response.clear();
            utf8_units_response.clear();
            utf8_nunits_response.clear();

            buf_response.resize(samples[sample_index].response.size());
            std::copy(samples[sample_index].response.begin(), samples[sample_index].response.end(), buf_response.begin());
            utf8_units_response.resize(buf_response.size());
            utf8_nunits_response.resize(buf_response.size());
            mark_utf8_units(buf_response.data(), utf8_units_response.data(), utf8_nunits_response.data(), buf_response.size());

            out_tokens_prompt.resize(buf_prompt.size() + n_max_tokens_overhead);
            out_tokens_response.resize(buf_response.size() + n_max_tokens_overhead);

            int n_tokens_prompt   = tokenize_buffer(lctx, buf_prompt,   out_tokens_prompt);
            int n_tokens_response = tokenize_buffer(lctx, buf_response, out_tokens_response);
            int total_tokens = out_tokens_prompt.size() + out_tokens_response.size();
            // generate sample starts at all token positions
            out_samples_begin.clear();
            out_samples_begin.push_back(0);
            out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
            size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
            for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
                out_samples_begin.push_back(sample_begin);
                out_samples_size.push_back(context_length);
            out_samples_begin.push_back(sample_index * context_length);

            // Concat and padding
            out_tokens.insert(out_tokens.end(), out_tokens_prompt.begin(), out_tokens_prompt.end());
            out_tokens.insert(out_tokens.end(), out_tokens_response.begin(), out_tokens_response.end());
            out_tokens.resize((sample_index + 1) * context_length);

            // Calculate mask positions
            masked_locations ml;
            ml.ir   = sample_index * context_length;
            ml.ires = sample_index * context_length + out_tokens_prompt.size();
            ml.ipad = total_tokens > context_length ? (sample_index + 1) * context_length : sample_index * context_length + total_tokens;

            masks.push_back(ml);
        }
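// Note (illustrative, not part of the original commit): the
// out_tokens.resize((sample_index + 1) * context_length) above pins every
// sample to exactly context_length tokens -- an overlong prompt+response pair
// is truncated (ipad then points one past the sample's slot), while a short
// pair is padded, with ipad marking where the padding begins so that no loss
// is computed on it.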
        } else {
            // split data into samples and tokenize each sample
            std::string data_str(buf.data(), buf.size());
            out_samples_begin.clear();
            out_samples_size.clear();
            out_tokens.clear();
            // Pop the last one
            out_samples_begin.pop_back();

            // find all positions of pattern sample_start
            size_t sample_begin = data_str.find(sample_start, 0);
            while (sample_begin != std::string::npos) {
                out_samples_begin.push_back(sample_begin);
                const size_t search_start = sample_begin + sample_start.size();
                sample_begin = data_str.find(sample_start, search_start);
            }
            if (out_samples_begin.size() == 0) {
                printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
                        __func__, sample_start.c_str());
                out_samples_begin.push_back(0);
            }

            out_samples_size.resize(out_samples_begin.size(), 0);

            std::vector<char> buf_sample;
            std::vector<llama_token> tok_sample;

            const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
            size_t found_too_big_sample   = 0;
            size_t found_too_small_sample = 0;
            size_t found_empty_sample     = 0;
            size_t found_min_sample_size  = SIZE_MAX;
            size_t found_max_sample_size  = 0;

            size_t max_token_text_size = 0;
            int n_vocab = llama_n_vocab(llama_get_model(lctx));
            for (llama_token token = 0; token < n_vocab; ++token) {
                max_token_text_size = std::max(
                        max_token_text_size,
                        strlen(llama_token_get_text(llama_get_model(lctx), token)));
            }

            // upper bound of context byte length.
            // strings with this byte length should always tokenize to at least context_length tokens.
            size_t context_byte_len = max_token_text_size * context_length;
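// Worked example (illustrative, not part of the original commit): if the
// longest token text in the vocabulary is 48 bytes and context_length is 512,
// then context_byte_len = 48 * 512 = 24576 -- a span of that many bytes can
// never be covered by fewer than 512 tokens, since each token consumes at
// most 48 bytes of text.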
            for (unsigned i = 0; i < out_samples_begin.size(); ++i) {
                // determine sample begin and end from pattern positions
                size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
                size_t sample_end   = overlapping_samples
                    ? std::min(
                        data_str.size(),
                        sample_begin + context_byte_len)
                    : (i + 1 < out_samples_begin.size()
                        ? out_samples_begin[i + 1]
                        : data_str.size());
                if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
                    // sample end is in the middle of an utf8 character.
                    // advance sample_end to the begin of the next utf8 character.
                    sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
                }
                size_t sample_size = sample_end - sample_begin;
                if (sample_size == 0) {
                    ++found_empty_sample;
                }

                if (sample_size > 0) {
                    // llama_tokenize expects zero terminated string,
                    // copy sample into buffer and zero terminate it.
                    buf_sample.resize(sample_size);
                    memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);

                    // printf("sample: '%s'\n", buf_sample.data());

                    // tokenize the sample
                    tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
                    int n_tokens = llama_tokenize(llama_get_model(lctx),
                        buf_sample.data(),
                        (int) buf_sample.size(),
                        tok_sample.data(),
                        (int) tok_sample.size(),
                        false, false);
                    if (n_tokens < 0) {
                        tok_sample.resize(-n_tokens);
                        n_tokens = llama_tokenize(llama_get_model(lctx),
                            buf_sample.data(),
                            (int) buf_sample.size(),
                            tok_sample.data(),
                            (int) tok_sample.size(),
                            false, false);
                        GGML_ASSERT(n_tokens >= 0);
                    }
                    GGML_ASSERT(n_tokens <= (int) tok_sample.size());

                    if ((size_t) n_tokens > context_length) {
                        ++found_too_big_sample;
                    } else if ((size_t) n_tokens < context_length) {
                        ++found_too_small_sample;
                    }
                    found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
                    found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);

                    // write out tokens, start and size of sample
                    // overwrite the string start position with the token start position
                    out_samples_begin[i] = out_tokens.size();
                    out_samples_size[i]  = (size_t) n_tokens;
                    out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
                } else {
                    out_samples_begin[i] = out_tokens.size();
                    out_samples_size[i]  = 0;
                }
            }
            if (found_too_big_sample > 0) {
                printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
                        __func__, found_too_big_sample, found_max_sample_size, context_length);
            }

            if (found_too_small_sample > 0) {
                printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
                        __func__, found_too_small_sample, found_min_sample_size, context_length);
            }

            if (found_empty_sample) {
                printf("%s: warning: found %zu empty samples.\n",
                        __func__, found_empty_sample);
            }
        }
        printf("%s: total number of samples: %zu\n",
                __func__, out_samples_begin.size());
nlohmann/json.hpp | 24689 (new file)
File diff suppressed because it is too large