Reshuffle the original sample order instead of the previously shuffled order

Otherwise, a resumed reshuffle will not result in the same sample order.
This commit is contained in:
xaedes 2023-09-14 18:21:23 +02:00
parent 3a9c1d7f5a
commit 0971fee710
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1439,7 +1439,13 @@ std::string mt19937_seed_to_state(unsigned seed) {
return mt19937_get_state(rng); return mt19937_get_state(rng);
} }
std::string shuffle_samples(const std::string& rng_state, size_t * begins, size_t * sizes, size_t count) { std::string shuffle_samples(
const std::string & rng_state,
const size_t * begins,
const size_t * sizes,
size_t * shuffled_begins,
size_t * shuffled_sizes,
size_t count) {
if (count == 0) return rng_state; if (count == 0) return rng_state;
std::mt19937 rng; std::mt19937 rng;
@ -1463,18 +1469,13 @@ std::string shuffle_samples(const std::string& rng_state, size_t * begins, size_
} }
// reorder begins and sizes by sorted indices // reorder begins and sizes by sorted indices
std::vector<size_t> reordered; for (unsigned i=0; i<count; ++i) {
reordered.resize(count); shuffled_begins[i] = begins[idcs[i]];
}
for (unsigned i=0; i<count; ++i) { for (unsigned i=0; i<count; ++i) {
reordered[i] = begins[idcs[i]]; shuffled_sizes[i] = sizes[idcs[i]];
} }
memcpy(begins, reordered.data(), sizeof(*begins)*reordered.size());
for (unsigned i=0; i<count; ++i) {
reordered[i] = sizes[idcs[i]];
}
memcpy(sizes, reordered.data(), sizeof(*sizes)*reordered.size());
return mt19937_get_state(rng); return mt19937_get_state(rng);
} }
@ -2672,6 +2673,8 @@ struct opt_callback_data {
size_t tokens_size; size_t tokens_size;
size_t * samples_begin; size_t * samples_begin;
size_t * samples_size; size_t * samples_size;
size_t * shuffled_samples_begin;
size_t * shuffled_samples_size;
size_t samples_count; size_t samples_count;
struct ggml_tensor * tokens_input; struct ggml_tensor * tokens_input;
struct ggml_tensor * target_probs; struct ggml_tensor * target_probs;
@ -2793,8 +2796,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
int used_samples = get_example_targets_batch( int used_samples = get_example_targets_batch(
data->lctx, data->lctx,
data->samples_begin, data->shuffled_samples_begin,
data->samples_size, data->shuffled_samples_size,
data->samples_count, data->samples_count,
data->tokens_data, data->tokens_data,
data->tokens_size, data->tokens_size,
@ -2816,6 +2819,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
data->lora->shuffle_rng_state_current, data->lora->shuffle_rng_state_current,
data->samples_begin, data->samples_begin,
data->samples_size, data->samples_size,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_count); data->samples_count);
data->lora->shuffle_next_sample = 0; data->lora->shuffle_next_sample = 0;
} }
@ -3196,10 +3201,16 @@ int main(int argc, char ** argv) {
lora.shuffle_next_sample = 0; lora.shuffle_next_sample = 0;
lora.shuffle_samples_hash = shuffle_samples_hash; lora.shuffle_samples_hash = shuffle_samples_hash;
} }
std::vector<size_t> train_shuffled_samples_begin;
std::vector<size_t> train_shuffled_samples_size;
train_shuffled_samples_begin.resize(train_samples_begin.size());
train_shuffled_samples_size.resize(train_samples_size.size());
lora.shuffle_rng_state_next = shuffle_samples( lora.shuffle_rng_state_next = shuffle_samples(
lora.shuffle_rng_state_current, lora.shuffle_rng_state_current,
train_samples_begin.data(), train_samples_begin.data(),
train_samples_size.data(), train_samples_size.data(),
train_shuffled_samples_begin.data(),
train_shuffled_samples_size.data(),
train_samples_size.size()); train_samples_size.size());
printf("%s: begin training\n", __func__); printf("%s: begin training\n", __func__);
@ -3215,6 +3226,8 @@ int main(int argc, char ** argv) {
opt_cb_data.tokens_size = train_tokens.size(); opt_cb_data.tokens_size = train_tokens.size();
opt_cb_data.samples_begin = train_samples_begin.data(); opt_cb_data.samples_begin = train_samples_begin.data();
opt_cb_data.samples_size = train_samples_size.data(); opt_cb_data.samples_size = train_samples_size.data();
opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
opt_cb_data.samples_count = train_samples_size.size(); opt_cb_data.samples_count = train_samples_size.size();
opt_cb_data.tokens_input = tokens_input; opt_cb_data.tokens_input = tokens_input;
opt_cb_data.target_probs = target_probs; opt_cb_data.target_probs = target_probs;