add train option "--sample-random-offsets"

Use samples beginning at random offsets.
The offset is applied only to the first sample in each batch context window.
Together with "--fill-with-next-samples" this may help when training for endless text generation.

For example given a dataset containing samples "abcd", "ABCD", "0123".
With context size of 8 and options "--fill-with-next-samples", "--no-separate-with-eos", "--no-separate-with-bos",
the context windows of batches could only be filled with "abcdABCD", "ABCDabcd", "0123abcd", etc.

With "--sample-random-offsets" the context windows can also be filled with "23abcdAB", "bcd0123A", etc.
This commit is contained in:
xaedes 2023-09-17 14:37:41 +02:00
parent bf2ad65836
commit d1bb6fb349
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
4 changed files with 32 additions and 4 deletions

View file

@ -211,6 +211,7 @@ int64_t get_example_targets_batch(
struct ggml_tensor * tokens_input, struct ggml_tensor * tokens_input,
struct ggml_tensor * target_probs, struct ggml_tensor * target_probs,
int64_t example_id, int64_t example_id,
const size_t * samples_offs,
const size_t * samples_begin, const size_t * samples_begin,
const size_t * samples_size, const size_t * samples_size,
size_t samples_count, size_t samples_count,
@ -218,7 +219,8 @@ int64_t get_example_targets_batch(
size_t n_train_data, size_t n_train_data,
bool separate_with_eos, bool separate_with_eos,
bool separate_with_bos, bool separate_with_bos,
bool fill_with_next_samples bool fill_with_next_samples,
bool sample_random_offsets
) { ) {
GGML_ASSERT(samples_count > 0); GGML_ASSERT(samples_count > 0);
GGML_ASSERT(tokens_input->n_dims == 2); GGML_ASSERT(tokens_input->n_dims == 2);
@ -238,8 +240,8 @@ int64_t get_example_targets_batch(
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) { for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k); // printf("%s: batch %d\n", __func__, k);
size_t sample_offs = 0;
size_t sample_idx = (example_id + used_samples) % samples_count; size_t sample_idx = (example_id + used_samples) % samples_count;
size_t sample_offs = sample_random_offsets ? samples_offs[sample_idx] : 0;
size_t sample_begin = samples_begin[sample_idx]; size_t sample_begin = samples_begin[sample_idx];
size_t sample_size = samples_size[sample_idx]; size_t sample_size = samples_size[sample_idx];
++used_samples; ++used_samples;
@ -308,6 +310,7 @@ std::string mt19937_seed_to_state(unsigned seed) {
std::string shuffle_samples( std::string shuffle_samples(
const std::string & rng_state, const std::string & rng_state,
size_t * shuffled_offs,
size_t * shuffled_begins, size_t * shuffled_begins,
size_t * shuffled_sizes, size_t * shuffled_sizes,
const size_t * begins, const size_t * begins,
@ -335,6 +338,11 @@ std::string shuffle_samples(
}); });
} }
// create random offsets
for (unsigned i=0; i<count; ++i) {
shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
}
// reorder begins and sizes by sorted indices // reorder begins and sizes by sorted indices
for (unsigned i=0; i<count; ++i) { for (unsigned i=0; i<count; ++i) {
shuffled_begins[i] = begins[idcs[i]]; shuffled_begins[i] = begins[idcs[i]];
@ -1048,6 +1056,7 @@ struct train_params_common get_default_train_params_common() {
params.fill_with_next_samples = false; params.fill_with_next_samples = false;
params.separate_with_eos = false; params.separate_with_eos = false;
params.separate_with_bos = true; params.separate_with_bos = true;
params.sample_random_offsets = false;
params.force_reshuffle = false; params.force_reshuffle = false;
params.opt_past = 0; params.opt_past = 0;
@ -1097,6 +1106,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : ""); fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : ""); fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : ""); fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
fprintf(stderr, " --sample-random-offsets Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n"); fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
fprintf(stderr, " --no-flash Don't use flash attention \n"); fprintf(stderr, " --no-flash Don't use flash attention \n");
fprintf(stderr, " --use-flash Use flash attention (default)\n"); fprintf(stderr, " --use-flash Use flash attention (default)\n");
@ -1221,6 +1231,8 @@ bool consume_common_train_arg(
params->separate_with_eos = false; params->separate_with_eos = false;
} else if (arg == "--no-separate-with-bos") { } else if (arg == "--no-separate-with-bos") {
params->separate_with_bos = false; params->separate_with_bos = false;
} else if (arg == "--sample-random-offsets") {
params->sample_random_offsets = true;
} else if (arg == "--force-reshuffle") { } else if (arg == "--force-reshuffle") {
params->force_reshuffle = true; params->force_reshuffle = true;
} else if (arg == "--no-flash") { } else if (arg == "--no-flash") {
@ -1433,6 +1445,7 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
data->tokens_input, data->tokens_input,
data->target_probs, data->target_probs,
train->shuffle_next_sample, train->shuffle_next_sample,
data->shuffled_samples_offs,
data->shuffled_samples_begin, data->shuffled_samples_begin,
data->shuffled_samples_size, data->shuffled_samples_size,
data->samples_count, data->samples_count,
@ -1440,7 +1453,8 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
data->tokens_size, data->tokens_size,
params->separate_with_eos, params->separate_with_eos,
params->separate_with_bos, params->separate_with_bos,
params->fill_with_next_samples); params->fill_with_next_samples,
params->sample_random_offsets);
train->train_samples += used_samples; train->train_samples += used_samples;
train->shuffle_next_sample += used_samples; train->shuffle_next_sample += used_samples;
@ -1452,6 +1466,7 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
train->shuffle_rng_state_current = train->shuffle_rng_state_next; train->shuffle_rng_state_current = train->shuffle_rng_state_next;
train->shuffle_rng_state_next = shuffle_samples( train->shuffle_rng_state_next = shuffle_samples(
train->shuffle_rng_state_current, train->shuffle_rng_state_current,
data->shuffled_samples_offs,
data->shuffled_samples_begin, data->shuffled_samples_begin,
data->shuffled_samples_size, data->shuffled_samples_size,
data->samples_begin, data->samples_begin,

View file

@ -56,6 +56,7 @@ struct train_params_common {
bool fill_with_next_samples; bool fill_with_next_samples;
bool separate_with_eos; bool separate_with_eos;
bool separate_with_bos; bool separate_with_bos;
bool sample_random_offsets;
bool force_reshuffle; bool force_reshuffle;
@ -93,6 +94,7 @@ struct train_opt_callback_data {
size_t tokens_size; size_t tokens_size;
size_t * samples_begin; size_t * samples_begin;
size_t * samples_size; size_t * samples_size;
size_t * shuffled_samples_offs;
size_t * shuffled_samples_begin; size_t * shuffled_samples_begin;
size_t * shuffled_samples_size; size_t * shuffled_samples_size;
size_t samples_count; size_t samples_count;
@ -153,6 +155,7 @@ int64_t get_example_targets_batch(
struct ggml_tensor * tokens_input, struct ggml_tensor * tokens_input,
struct ggml_tensor * target_probs, struct ggml_tensor * target_probs,
int64_t example_id, int64_t example_id,
const size_t * samples_offs,
const size_t * samples_begin, const size_t * samples_begin,
const size_t * samples_size, const size_t * samples_size,
size_t samples_count, size_t samples_count,
@ -160,7 +163,8 @@ int64_t get_example_targets_batch(
size_t n_train_data, size_t n_train_data,
bool separate_with_eos, bool separate_with_eos,
bool separate_with_bos, bool separate_with_bos,
bool fill_with_next_samples); bool fill_with_next_samples,
bool sample_random_offsets);
void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state); void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
@ -169,6 +173,7 @@ mt19937_state mt19937_seed_to_state(unsigned seed);
mt19937_state shuffle_samples( mt19937_state shuffle_samples(
const mt19937_state & rng_state, const mt19937_state & rng_state,
size_t * shuffled_offs,
size_t * shuffled_begins, size_t * shuffled_begins,
size_t * shuffled_sizes, size_t * shuffled_sizes,
const size_t * begins, const size_t * begins,

View file

@ -1876,12 +1876,15 @@ int main(int argc, char ** argv) {
train->shuffle_next_sample = 0; train->shuffle_next_sample = 0;
train->shuffle_samples_hash = shuffle_samples_hash; train->shuffle_samples_hash = shuffle_samples_hash;
} }
std::vector<size_t> train_shuffled_samples_offs;
std::vector<size_t> train_shuffled_samples_begin; std::vector<size_t> train_shuffled_samples_begin;
std::vector<size_t> train_shuffled_samples_size; std::vector<size_t> train_shuffled_samples_size;
train_shuffled_samples_offs.resize(train_samples_begin.size());
train_shuffled_samples_begin.resize(train_samples_begin.size()); train_shuffled_samples_begin.resize(train_samples_begin.size());
train_shuffled_samples_size.resize(train_samples_size.size()); train_shuffled_samples_size.resize(train_samples_size.size());
train->shuffle_rng_state_next = shuffle_samples( train->shuffle_rng_state_next = shuffle_samples(
train->shuffle_rng_state_current, train->shuffle_rng_state_current,
train_shuffled_samples_offs.data(),
train_shuffled_samples_begin.data(), train_shuffled_samples_begin.data(),
train_shuffled_samples_size.data(), train_shuffled_samples_size.data(),
train_samples_begin.data(), train_samples_begin.data(),
@ -1909,6 +1912,7 @@ int main(int argc, char ** argv) {
opt_cb_data.tokens_size = train_tokens.size(); opt_cb_data.tokens_size = train_tokens.size();
opt_cb_data.samples_begin = train_samples_begin.data(); opt_cb_data.samples_begin = train_samples_begin.data();
opt_cb_data.samples_size = train_samples_size.data(); opt_cb_data.samples_size = train_samples_size.data();
opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data(); opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data(); opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
opt_cb_data.samples_count = train_samples_size.size(); opt_cb_data.samples_count = train_samples_size.size();

View file

@ -1059,12 +1059,15 @@ int main(int argc, char ** argv) {
train->shuffle_next_sample = 0; train->shuffle_next_sample = 0;
train->shuffle_samples_hash = shuffle_samples_hash; train->shuffle_samples_hash = shuffle_samples_hash;
} }
std::vector<size_t> train_shuffled_samples_offs;
std::vector<size_t> train_shuffled_samples_begin; std::vector<size_t> train_shuffled_samples_begin;
std::vector<size_t> train_shuffled_samples_size; std::vector<size_t> train_shuffled_samples_size;
train_shuffled_samples_offs.resize(train_samples_begin.size());
train_shuffled_samples_begin.resize(train_samples_begin.size()); train_shuffled_samples_begin.resize(train_samples_begin.size());
train_shuffled_samples_size.resize(train_samples_size.size()); train_shuffled_samples_size.resize(train_samples_size.size());
train->shuffle_rng_state_next = shuffle_samples( train->shuffle_rng_state_next = shuffle_samples(
train->shuffle_rng_state_current, train->shuffle_rng_state_current,
train_shuffled_samples_offs.data(),
train_shuffled_samples_begin.data(), train_shuffled_samples_begin.data(),
train_shuffled_samples_size.data(), train_shuffled_samples_size.data(),
train_samples_begin.data(), train_samples_begin.data(),
@ -1091,6 +1094,7 @@ int main(int argc, char ** argv) {
opt_cb_data.tokens_size = train_tokens.size(); opt_cb_data.tokens_size = train_tokens.size();
opt_cb_data.samples_begin = train_samples_begin.data(); opt_cb_data.samples_begin = train_samples_begin.data();
opt_cb_data.samples_size = train_samples_size.data(); opt_cb_data.samples_size = train_samples_size.data();
opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data(); opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data(); opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
opt_cb_data.samples_count = train_samples_size.size(); opt_cb_data.samples_count = train_samples_size.size();