From d1bb6fb3499efefac9e1eb7bab0cc1fdf08e66b3 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 17 Sep 2023 14:37:41 +0200
Subject: [PATCH] add train option "--sample-random-offsets"

Use samples beginning at random offsets. The offset is only applied to the
first sample in each batch context window. Together with
"--fill-with-next-samples" this may help when training endless text
generation.

For example, given a dataset containing the samples "abcd", "ABCD", "0123":
with a context size of 8 and the options "--fill-with-next-samples",
"--no-separate-with-eos" and "--no-separate-with-bos", the context windows of
batches could only be filled with "abcdABCD", "ABCDabcd", "0123abcd", etc.
With "--sample-random-offsets" they can also be filled with "23abcdAB",
"bcd0123A", etc.
---
 common/train.cpp                              | 21 ++++++++++++++++---
 common/train.h                                |  7 ++++++-
 examples/finetune/finetune.cpp                |  4 ++++
 .../train-text-from-scratch.cpp               |  4 ++++
 4 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/common/train.cpp b/common/train.cpp
index 991679292..10e0107eb 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -211,6 +211,7 @@ int64_t get_example_targets_batch(
     struct ggml_tensor * tokens_input,
     struct ggml_tensor * target_probs,
     int64_t example_id,
+    const size_t * samples_offs,
     const size_t * samples_begin,
     const size_t * samples_size,
     size_t samples_count,
@@ -218,7 +219,8 @@ int64_t get_example_targets_batch(
     size_t n_train_data,
     bool separate_with_eos,
     bool separate_with_bos,
-    bool fill_with_next_samples
+    bool fill_with_next_samples,
+    bool sample_random_offsets
 ) {
     GGML_ASSERT(samples_count > 0);
     GGML_ASSERT(tokens_input->n_dims == 2);
@@ -238,8 +240,8 @@ int64_t get_example_targets_batch(
     // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
     for (int k=0; k<n_batch; ++k) {
     fprintf(stderr, "  --separate-with-bos            When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
     fprintf(stderr, "  --no-separate-with-eos         When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
     fprintf(stderr, "  --no-separate-with-bos         When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
+    fprintf(stderr, "  --sample-random-offsets        Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
     fprintf(stderr, "  --force-reshuffle              Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
     fprintf(stderr, "  --no-flash                     Don't use flash attention \n");
     fprintf(stderr, "  --use-flash                    Use flash attention (default)\n");
@@ -1221,6 +1231,8 @@ bool consume_common_train_arg(
         params->separate_with_eos = false;
     } else if (arg == "--no-separate-with-bos") {
         params->separate_with_bos = false;
+    } else if (arg == "--sample-random-offsets") {
+        params->sample_random_offsets = true;
     } else if (arg == "--force-reshuffle") {
         params->force_reshuffle = true;
     } else if (arg == "--no-flash") {
@@ -1433,6 +1445,7 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
         data->tokens_input,
         data->target_probs,
         train->shuffle_next_sample,
+        data->shuffled_samples_offs,
         data->shuffled_samples_begin,
         data->shuffled_samples_size,
         data->samples_count,
@@ -1440,7 +1453,8 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
         data->tokens_size,
         params->separate_with_eos,
         params->separate_with_bos,
-        params->fill_with_next_samples);
+        params->fill_with_next_samples,
+        params->sample_random_offsets);
     train->train_samples += used_samples;
     train->shuffle_next_sample += used_samples;
@@ -1452,6 +1466,7 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
         train->shuffle_rng_state_current = train->shuffle_rng_state_next;
         train->shuffle_rng_state_next = shuffle_samples(
             train->shuffle_rng_state_current,
+            data->shuffled_samples_offs,
             data->shuffled_samples_begin,
             data->shuffled_samples_size,
             data->samples_begin,
diff --git a/common/train.h b/common/train.h
index 4857ba922..6ef1f9fc5 100644
--- a/common/train.h
+++ b/common/train.h
@@ -56,6 +56,7 @@ struct train_params_common {
     bool fill_with_next_samples;
     bool separate_with_eos;
     bool separate_with_bos;
+    bool sample_random_offsets;

     bool force_reshuffle;
@@ -93,6 +94,7 @@ struct train_opt_callback_data {
     size_t tokens_size;
     size_t * samples_begin;
     size_t * samples_size;
+    size_t * shuffled_samples_offs;
     size_t * shuffled_samples_begin;
     size_t * shuffled_samples_size;
     size_t samples_count;
@@ -153,6 +155,7 @@ int64_t get_example_targets_batch(
     struct ggml_tensor * tokens_input,
     struct ggml_tensor * target_probs,
     int64_t example_id,
+    const size_t * samples_offs,
     const size_t * samples_begin,
     const size_t * samples_size,
     size_t samples_count,
@@ -160,7 +163,8 @@ int64_t get_example_targets_batch(
     size_t n_train_data,
     bool separate_with_eos,
     bool separate_with_bos,
-    bool fill_with_next_samples);
+    bool fill_with_next_samples,
+    bool sample_random_offsets);

 void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
@@ -169,6 +173,7 @@ mt19937_state mt19937_seed_to_state(unsigned seed);

 mt19937_state shuffle_samples(
         const mt19937_state & rng_state,
+        size_t * shuffled_offs,
         size_t * shuffled_begins,
         size_t * shuffled_sizes,
         const size_t * begins,
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index ea1a68b0d..e631451a5 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1876,12 +1876,15 @@ int main(int argc, char ** argv) {
         train->shuffle_next_sample = 0;
         train->shuffle_samples_hash = shuffle_samples_hash;
     }
+    std::vector<size_t> train_shuffled_samples_offs;
     std::vector<size_t> train_shuffled_samples_begin;
     std::vector<size_t> train_shuffled_samples_size;
+    train_shuffled_samples_offs.resize(train_samples_begin.size());
     train_shuffled_samples_begin.resize(train_samples_begin.size());
     train_shuffled_samples_size.resize(train_samples_size.size());
     train->shuffle_rng_state_next = shuffle_samples(
         train->shuffle_rng_state_current,
+        train_shuffled_samples_offs.data(),
         train_shuffled_samples_begin.data(),
         train_shuffled_samples_size.data(),
         train_samples_begin.data(),
@@ -1909,6 +1912,7 @@ int main(int argc, char ** argv) {
     opt_cb_data.tokens_size = train_tokens.size();
     opt_cb_data.samples_begin = train_samples_begin.data();
     opt_cb_data.samples_size = train_samples_size.data();
+    opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
     opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
     opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
     opt_cb_data.samples_count = train_samples_size.size();
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index d5cf42665..0da7ec11b 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1059,12 +1059,15 @@ int main(int argc, char ** argv) {
         train->shuffle_next_sample = 0;
         train->shuffle_samples_hash = shuffle_samples_hash;
     }
+    std::vector<size_t> train_shuffled_samples_offs;
     std::vector<size_t> train_shuffled_samples_begin;
     std::vector<size_t> train_shuffled_samples_size;
+    train_shuffled_samples_offs.resize(train_samples_begin.size());
     train_shuffled_samples_begin.resize(train_samples_begin.size());
     train_shuffled_samples_size.resize(train_samples_size.size());
     train->shuffle_rng_state_next = shuffle_samples(
         train->shuffle_rng_state_current,
+        train_shuffled_samples_offs.data(),
         train_shuffled_samples_begin.data(),
         train_shuffled_samples_size.data(),
         train_samples_begin.data(),
@@ -1091,6 +1094,7 @@ int main(int argc, char ** argv) {
     opt_cb_data.tokens_size = train_tokens.size();
     opt_cb_data.samples_begin = train_samples_begin.data();
     opt_cb_data.samples_size = train_samples_size.data();
+    opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
     opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
     opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
     opt_cb_data.samples_count = train_samples_size.size();
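
As a rough illustration of the behaviour described in the commit message (a
hypothetical character-level sketch, not the token-level code in this patch;
separator tokens and sample shuffling are omitted): a context window is filled
starting at a random offset into its first sample, and with
"--fill-with-next-samples" it then continues with the following samples from
offset 0.

    #include <cstddef>
    #include <iostream>
    #include <random>
    #include <string>
    #include <vector>

    int main() {
        const std::vector<std::string> samples = {"abcd", "ABCD", "0123"};
        const size_t n_ctx = 8;
        std::mt19937 rng(42);

        // pick the sample that starts this context window and a random offset into it
        size_t idx  = std::uniform_int_distribution<size_t>(0, samples.size() - 1)(rng);
        size_t offs = std::uniform_int_distribution<size_t>(0, samples[idx].size() - 1)(rng);

        std::string window;
        while (window.size() < n_ctx) {
            window += samples[idx].substr(offs, n_ctx - window.size());
            offs = 0;                          // only the first sample starts at a random offset
            idx  = (idx + 1) % samples.size(); // "--fill-with-next-samples": continue with next sample
        }
        std::cout << window << "\n";           // e.g. "23abcdAB" when idx==2 and offs==2 are drawn
    }

In the patch itself this is wired up by having shuffle_samples fill the new
shuffled_offs array (one random offset per shuffled sample), which
train_opt_callback passes to get_example_targets_batch as samples_offs.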