reshuffle original sample order instead of the previous shuffled order

otherwise resumed reshuffle will not result in same sample order
This commit is contained in:
xaedes 2023-09-14 18:21:23 +02:00
parent 3a9c1d7f5a
commit 0971fee710
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1439,7 +1439,13 @@ std::string mt19937_seed_to_state(unsigned seed) {
return mt19937_get_state(rng);
}
std::string shuffle_samples(const std::string& rng_state, size_t * begins, size_t * sizes, size_t count) {
std::string shuffle_samples(
const std::string & rng_state,
const size_t * begins,
const size_t * sizes,
size_t * shuffled_begins,
size_t * shuffled_sizes,
size_t count) {
if (count == 0) return rng_state;
std::mt19937 rng;
@ -1463,18 +1469,13 @@ std::string shuffle_samples(const std::string& rng_state, size_t * begins, size_
}
// reorder begins and sizes by sorted indices
std::vector<size_t> reordered;
reordered.resize(count);
for (unsigned i=0; i<count; ++i) {
shuffled_begins[i] = begins[idcs[i]];
}
for (unsigned i=0; i<count; ++i) {
reordered[i] = begins[idcs[i]];
shuffled_sizes[i] = sizes[idcs[i]];
}
memcpy(begins, reordered.data(), sizeof(*begins)*reordered.size());
for (unsigned i=0; i<count; ++i) {
reordered[i] = sizes[idcs[i]];
}
memcpy(sizes, reordered.data(), sizeof(*sizes)*reordered.size());
return mt19937_get_state(rng);
}
@ -2672,6 +2673,8 @@ struct opt_callback_data {
size_t tokens_size;
size_t * samples_begin;
size_t * samples_size;
size_t * shuffled_samples_begin;
size_t * shuffled_samples_size;
size_t samples_count;
struct ggml_tensor * tokens_input;
struct ggml_tensor * target_probs;
@ -2793,8 +2796,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
int used_samples = get_example_targets_batch(
data->lctx,
data->samples_begin,
data->samples_size,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_count,
data->tokens_data,
data->tokens_size,
@ -2816,6 +2819,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
data->lora->shuffle_rng_state_current,
data->samples_begin,
data->samples_size,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_count);
data->lora->shuffle_next_sample = 0;
}
@ -3196,10 +3201,16 @@ int main(int argc, char ** argv) {
lora.shuffle_next_sample = 0;
lora.shuffle_samples_hash = shuffle_samples_hash;
}
std::vector<size_t> train_shuffled_samples_begin;
std::vector<size_t> train_shuffled_samples_size;
train_shuffled_samples_begin.resize(train_samples_begin.size());
train_shuffled_samples_size.resize(train_samples_size.size());
lora.shuffle_rng_state_next = shuffle_samples(
lora.shuffle_rng_state_current,
train_samples_begin.data(),
train_samples_size.data(),
train_shuffled_samples_begin.data(),
train_shuffled_samples_size.data(),
train_samples_size.size());
printf("%s: begin training\n", __func__);
@ -3215,6 +3226,8 @@ int main(int argc, char ** argv) {
opt_cb_data.tokens_size = train_tokens.size();
opt_cb_data.samples_begin = train_samples_begin.data();
opt_cb_data.samples_size = train_samples_size.data();
opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
opt_cb_data.samples_count = train_samples_size.size();
opt_cb_data.tokens_input = tokens_input;
opt_cb_data.target_probs = target_probs;