From 0971fee710fbd881ca2688b7cbdc9feac723ace1 Mon Sep 17 00:00:00 2001 From: xaedes Date: Thu, 14 Sep 2023 18:21:23 +0200 Subject: [PATCH] reshuffle original sample order instead of the previous shuffled order otherwise resumed reshuffle will not result in same sample order --- examples/finetune/finetune.cpp | 37 +++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 260b89ce3..37be889d3 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -1439,7 +1439,13 @@ std::string mt19937_seed_to_state(unsigned seed) { return mt19937_get_state(rng); } -std::string shuffle_samples(const std::string& rng_state, size_t * begins, size_t * sizes, size_t count) { +std::string shuffle_samples( + const std::string & rng_state, + const size_t * begins, + const size_t * sizes, + size_t * shuffled_begins, + size_t * shuffled_sizes, + size_t count) { if (count == 0) return rng_state; std::mt19937 rng; @@ -1463,18 +1469,13 @@ std::string shuffle_samples(const std::string& rng_state, size_t * begins, size_ } // reorder begins and sizes by sorted indices - std::vector reordered; - reordered.resize(count); + for (unsigned i=0; ilctx, - data->samples_begin, - data->samples_size, + data->shuffled_samples_begin, + data->shuffled_samples_size, data->samples_count, data->tokens_data, data->tokens_size, @@ -2816,6 +2819,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) { data->lora->shuffle_rng_state_current, data->samples_begin, data->samples_size, + data->shuffled_samples_begin, + data->shuffled_samples_size, data->samples_count); data->lora->shuffle_next_sample = 0; } @@ -3196,10 +3201,16 @@ int main(int argc, char ** argv) { lora.shuffle_next_sample = 0; lora.shuffle_samples_hash = shuffle_samples_hash; } + std::vector train_shuffled_samples_begin; + std::vector train_shuffled_samples_size; + train_shuffled_samples_begin.resize(train_samples_begin.size()); + train_shuffled_samples_size.resize(train_samples_size.size()); lora.shuffle_rng_state_next = shuffle_samples( lora.shuffle_rng_state_current, train_samples_begin.data(), train_samples_size.data(), + train_shuffled_samples_begin.data(), + train_shuffled_samples_size.data(), train_samples_size.size()); printf("%s: begin training\n", __func__); @@ -3215,6 +3226,8 @@ int main(int argc, char ** argv) { opt_cb_data.tokens_size = train_tokens.size(); opt_cb_data.samples_begin = train_samples_begin.data(); opt_cb_data.samples_size = train_samples_size.data(); + opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data(); + opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data(); opt_cb_data.samples_count = train_samples_size.size(); opt_cb_data.tokens_input = tokens_input; opt_cb_data.target_probs = target_probs;