diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index aa8fbd6b6..8e06838ad 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -202,9 +202,16 @@ struct my_llama_model { std::vector layers; - uint32_t train_its = 0; - uint32_t train_samples = 0; - uint32_t train_tokens = 0; + uint64_t train_its = 0; + uint64_t train_samples = 0; + uint64_t train_tokens = 0; + uint64_t train_epochs = 0; + + size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes) + std::string shuffle_rng_state_current; + std::string shuffle_rng_state_next; + size_t shuffle_sample_count; + size_t shuffle_next_sample; }; // gguf constants @@ -242,13 +249,19 @@ static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"; static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"; -static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"; -static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"; -static const char * LLM_KV_TRAINING_TYPE = "training.type"; -static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; -static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; -static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; -static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; +static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"; +static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"; +static const char * LLM_KV_TRAINING_TYPE = "training.type"; +static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; +static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; +static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; +static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; +static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count"; +static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash"; +static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash"; +static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state"; +static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count"; +static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample"; // gguf constants (sync with gguf.py) @@ -312,6 +325,7 @@ static void init_model(struct my_llama_model * model) { model->train_its = 0; model->train_samples = 0; model->train_tokens = 0; + model->train_epochs = 0; std::vector tn_buf; tn_buf.resize(GGML_MAX_NAME); @@ -623,148 +637,450 @@ static struct ggml_tensor * llama_build_train_graphs( return t36; } -static void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { - int n_tokens = tokens_input->ne[0]; - int n_vocab = target_logits->ne[0]; +static int get_example_targets_batch( + struct llama_context * lctx, + const size_t * samples_begin, + const size_t * samples_size, + size_t samples_count, + const llama_token * train_data, + 
size_t n_train_data, + int example_id, + struct ggml_tensor * tokens_input, + struct ggml_tensor * target_probs, + bool separate_with_eos, + bool separate_with_bos, + bool fill_with_next_samples) { - size_t sample = train_samples[example_id % n_train_samples]; - GGML_ASSERT(sample+n_tokens-1 < n_train_data); - - ggml_set_f32(target_logits, -1.0f/n_vocab); - ggml_set_f32(target_probs, 0.0f); - ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx)); - for (int i=1; i<n_tokens+1; ++i) { - int token = clamp(train_data[sample+i-1], 0, n_vocab-1); - set_f32_2d(target_logits, token, i-1, +1.0f); - set_f32_2d(target_probs, token, i-1, +1.0f); - if (i<n_tokens) { - ggml_set_i32_1d(tokens_input, i, token); - } - } -} - -static void get_example_targets_batch(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { GGML_ASSERT(tokens_input->n_dims == 2); - GGML_ASSERT(target_logits->n_dims == 3); GGML_ASSERT(target_probs->n_dims == 3); - int n_vocab = target_logits->ne[0]; + int n_vocab = target_probs->ne[0]; int n_tokens = tokens_input->ne[0]; int n_batch = tokens_input->ne[1]; - GGML_ASSERT(n_tokens == target_logits->ne[1]); - GGML_ASSERT(n_batch == target_logits->ne[2]); GGML_ASSERT(n_vocab == target_probs->ne[0]); GGML_ASSERT(n_tokens == target_probs->ne[1]); GGML_ASSERT(n_batch == target_probs->ne[2]); - ggml_set_f32(target_logits, -1.0f/n_vocab); + int used_samples = 0; + ggml_set_f32(target_probs, 0.0f); + int bos = llama_token_bos(lctx); + int eos = llama_token_eos(lctx); // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); for (int k=0; k<n_batch; ++k) { // printf("%s: batch %d\n", __func__, k); - size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples]; - GGML_ASSERT(sample+n_tokens-1 < n_train_data); - - set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx)); - for (int i=1; i<n_tokens+1; ++i) { - int token = clamp(train_data[sample+i-1], 0, n_vocab-1); - set_f32_3d(target_logits, token, i-1, k, +1.0f); - set_f32_3d(target_probs, token, i-1, k, +1.0f); - if (i<n_tokens) { - set_i32_2d(tokens_input, i, k, token); - } - } + size_t sample_idx = (example_id + used_samples) % samples_count; + size_t sample_offs = 0; + size_t sample_begin = samples_begin[sample_idx]; + size_t sample_size = samples_size[sample_idx]; + ++used_samples; + + GGML_ASSERT(sample_begin+sample_size-1 < n_train_data); + + ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos); + bool sample_separation_eos = !separate_with_eos; + bool sample_separation_bos = true; + for (int i=0; i<n_tokens; ++i) { + llama_token token = eos; + if (sample_offs >= sample_size && fill_with_next_samples) { + if (!sample_separation_eos) { + // insert eos token to separate samples + sample_separation_eos = true; + } else if (!sample_separation_bos) { + // insert bos token to separate samples + sample_separation_bos = true; + token = bos; + } else { + // sample separation is done, continue with next sample + sample_separation_eos = !separate_with_eos; + sample_separation_bos = !separate_with_bos; + sample_offs = 0; + sample_idx = (example_id + used_samples) % samples_count; + sample_begin = samples_begin[sample_idx]; + sample_size = samples_size[sample_idx]; + ++used_samples; + } + } + // note: no else-if here + if (sample_offs < sample_size) { + token = clamp(train_data[sample_begin+sample_offs], 0, n_vocab-1); + ++sample_offs; + } + ggml_set_f32_nd(target_probs, token, i, k, 0, +1.0f); + if (i+1<n_tokens) { + ggml_set_i32_nd(tokens_input, i+1, k, 0, 0, token); + } + } + } + + return used_samples; +} + -static int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) { - FILE * fp = std::fopen(filename, "rb"); - if (fp == NULL) { - return 0; +#ifdef __GNUC__ +#ifdef __MINGW32__ +__attribute__((format(gnu_printf, 1, 2))) +#else +__attribute__((format(printf, 1, 2))) +#endif +#endif +static std::string format(const char * fmt, ...)
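The packing loop in get_example_targets_batch above is easier to follow in isolation. Below is a minimal, ggml-free sketch of the --fill-with-next-samples path: it fills one context window from a list of (begin, size) samples and returns how many samples were consumed, mirroring the eos/bos separation state machine above. The function name, placeholder token ids, and use of plain std::vector instead of ggml tensors are illustrative, not part of the patch.

    #include <cstddef>
    #include <vector>

    // Sketch: pack one fixed-size context from (begin, size) samples,
    // optionally separating consecutive samples with EOS/BOS markers.
    static size_t pack_context(
            const std::vector<int>    & data,
            const std::vector<size_t> & begins,
            const std::vector<size_t> & sizes,
            size_t next_sample,        // index of first (shuffled) sample to use
            size_t n_tokens,           // context length to fill
            bool   separate_with_eos,
            bool   separate_with_bos,
            std::vector<int> & out) {
        const int BOS = 1, EOS = 2;    // placeholder ids
        size_t used = 0;
        size_t idx  = (next_sample + used) % begins.size();
        size_t offs = 0;
        ++used;
        bool sep_eos = !separate_with_eos;
        bool sep_bos = true;           // the context already starts with BOS
        out.clear();
        out.push_back(BOS);
        while (out.size() < n_tokens) {
            int token = EOS;
            if (offs >= sizes[idx]) {
                if (!sep_eos) {
                    sep_eos = true;            // emit EOS once at the boundary
                } else if (!sep_bos) {
                    sep_bos = true;            // emit BOS once at the boundary
                    token = BOS;
                } else {                       // advance to the next sample
                    sep_eos = !separate_with_eos;
                    sep_bos = !separate_with_bos;
                    offs = 0;
                    idx  = (next_sample + used) % begins.size();
                    ++used;
                }
            }
            if (offs < sizes[idx]) {           // note: no else-if, as above
                token = data[begins[idx] + offs++];
            }
            out.push_back(token);
        }
        return used;   // caller advances shuffle_next_sample by this amount
    }

Returning the number of consumed samples is what lets the caller advance shuffle_next_sample deterministically, so an interrupted run can later resume at the exact sample boundary.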
{ + va_list ap, ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); + std::vector<char> buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +struct llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + size = 0; + } else { + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } } + size_t tell() const { #ifdef _WIN32 - GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_END) == 0); + __int64 ret = _ftelli64(fp); #else - GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_END) == 0); + long ret = std::ftell(fp); #endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return (size_t) ret; + } - size_t size = 0; + void seek(size_t offset, int whence) { #ifdef _WIN32 - __int64 ret = _ftelli64(fp); - size = ret; + int ret = _fseeki64(fp, (__int64) offset, whence); #else - long ret = std::ftell(fp); - size = ret; + int ret = std::fseek(fp, (long) offset, whence); #endif + GGML_ASSERT(ret == 0); // same + } -#ifdef _WIN32 - GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_SET) == 0); -#else - GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_SET) == 0); -#endif + void read_raw(void * ptr, size_t size) { + if (size == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, size, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error(std::string("unexpectedly reached end of file")); + } + } + + std::uint32_t read_u32() { + std::uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + std::string read_string(std::uint32_t len) { + std::vector<char> chars(len); + read_raw(chars.data(), len); + return std::string(chars.data(), len); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +static size_t utf8_len(char src) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast<uint8_t>(src) >> 4; + return lookup[highbits]; +} + +// mark each byte with its utf8 unit number. +// returns the number of utf8 characters. +// e.g. when bytes == '\x61\xD0\xB0\x62', +// then utf8_units will become [0,0,1,0] +// utf8_nunits will become [1,2,2,1] and 3 is returned. +// bytes where utf8_units is zero, are the begin of an utf8 character. +static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) { + size_t offs = 0; + size_t count_utf8 = 0; + while(offs < count) { + size_t len = utf8_len(bytes[offs]); + for (size_t i=0; i<len; ++i) { + utf8_units[offs+i] = i; + utf8_nunits[offs+i] = len; + } + offs += len; + ++count_utf8; + } + return count_utf8; +} + +static size_t tokenize_file( + struct llama_context * lctx, + const char * filename, + const std::string & sample_start, + bool include_sample_start, + bool overlapping_samples, + unsigned context_length, + std::vector<llama_token> & out_tokens, + std::vector<size_t> & out_samples_begin, + std::vector<size_t> & out_samples_size) { + struct llama_file f(filename, "rb"); + + if (f.size == 0) { + out_tokens.clear(); + out_samples_begin.clear(); + out_samples_size.clear(); + printf("%s: warning: empty or not existing training data file '%s'\n", + __func__, filename); + return out_tokens.size(); + } + + // account for possible leading whitespace that will be added by tokenizer + // e.g.
'\t' will be tokenized by llama spm tokenizer to [29871, 12] + const int n_max_tokens_overhead = 1; std::vector buf; - buf.resize(size+1); - out.resize(size+1); + buf.resize(f.size+1); - if (std::fread(buf.data(), size, 1, fp) != 1) { - die("unexpectedly reached end of file"); - } - if (ferror(fp)) { - die_fmt("fread failed: %s", strerror(errno)); - } + f.read_raw(buf.data(), f.size); + buf[f.size] = '\0'; - buf[size] = '\0'; + std::vector utf8_units; + std::vector utf8_nunits; + utf8_units.resize(buf.size()); + utf8_nunits.resize(buf.size()); + size_t n_utf8_chars = mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size()); - int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); - if (n_tokens < 0) { - out.resize(-n_tokens); - n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); - } - GGML_ASSERT(n_tokens >= 0); - out.resize(n_tokens); + if (sample_start.size() == 0) { + // tokenize all data at once + out_tokens.resize(buf.size() + n_max_tokens_overhead); - bool verify = false; - if (verify) { - const char * in = buf.data(); - const char * end = buf.data() + buf.size(); - for (int i = 0; i < (int) out.size(); ++i) { - std::string s = llama_token_to_piece(lctx, out[i]); - int len = s.length(); - if (in >= end) { - printf("%s: unexpected end of original text.\n", __func__); - break; + int n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), out_tokens.size(), false); + if (n_tokens < 0) { + out_tokens.resize(-n_tokens); + n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), out_tokens.size(), false); + } + if (n_tokens >= 0) { + out_tokens.resize(n_tokens); + } + + // generate sample starts at all token positions + out_samples_begin.clear(); + out_samples_begin.push_back(0); + out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size())); + size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0; + for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) { + out_samples_begin.push_back(sample_begin); + out_samples_size.push_back(context_length); + } + } else { + // split data into samples and tokenize each sample + std::string data_str(buf.data(), buf.size()-1); + out_samples_begin.clear(); + out_samples_size.clear(); + out_tokens.clear(); + + // find all positions of pattern sample_start + size_t sample_begin = data_str.find(sample_start, 0); + while (sample_begin != std::string::npos) { + out_samples_begin.push_back(sample_begin); + const size_t search_start = sample_begin + sample_start.size(); + sample_begin = data_str.find(sample_start, search_start); + } + if (out_samples_begin.size() == 0) { + printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n", + __func__, sample_start.c_str()); + out_samples_begin.push_back(0); + } + + out_samples_size.resize(out_samples_begin.size(), 0); + + std::vector buf_sample; + std::vector tok_sample; + + const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size()); + size_t found_too_big_sample = 0; + size_t found_too_small_sample = 0; + size_t found_empty_sample = 0; + size_t found_min_sample_size = SIZE_MAX; + size_t found_max_sample_size = 0; + + size_t max_token_text_size = 0; + int n_vocab = llama_n_vocab(lctx); + for (llama_token token=0; token < n_vocab; ++token) { + max_token_text_size = std::max( + max_token_text_size, + strlen(llama_token_get_text(lctx, token))); + } + + // upper bound of context byte length. 
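The sample-splitting branch above hinges on collecting the byte offset of every occurrence of the start pattern with std::string::find. A self-contained sketch of that scan (function name and example data are illustrative):

    #include <cstdio>
    #include <string>
    #include <vector>

    // Collect the byte offset of every occurrence of `sample_start`,
    // so each occurrence can begin a training sample.
    static std::vector<size_t> find_sample_starts(const std::string & data,
                                                  const std::string & sample_start) {
        std::vector<size_t> begins;
        size_t pos = data.find(sample_start, 0);
        while (pos != std::string::npos) {
            begins.push_back(pos);
            // continue searching after this occurrence
            pos = data.find(sample_start, pos + sample_start.size());
        }
        return begins;
    }

    int main() {
        const std::string data = "<s>first sample</s><s>second sample</s>";
        for (size_t begin : find_sample_starts(data, "<s>")) {
            printf("sample begins at byte %zu\n", begin);   // prints 0 and 19
        }
    }

Whether the matched pattern itself belongs to the sample is then just an offset question, which is what the sample_begin_offset computed from --include-sample-start expresses above.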
+ // strings with this byte length should always tokenize to at least context_length tokens. + size_t context_byte_len = max_token_text_size*context_length; + + for (unsigned i=0; i<out_samples_begin.size(); ++i) { + // determine sample begin and end from pattern positions + size_t sample_begin = out_samples_begin[i] + sample_begin_offset; + size_t sample_end = overlapping_samples ? sample_begin + context_byte_len : (i+1 < out_samples_begin.size() ? out_samples_begin[i+1] : data_str.size()); + if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) { + // sample end is in the middle of an utf8 character. + // advance sample_end to the begin of the next utf8 character. + sample_end += utf8_nunits[sample_end] - utf8_units[sample_end]; } - const bool matches = (strncmp(in, s.c_str(), len) == 0); - if (matches) { - in += len; + size_t sample_size = sample_end - sample_begin; + if (sample_size == 0) { + ++found_empty_sample; + } + + if (sample_size > 0) { + // llama_tokenize expects zero terminated string, + // copy sample into buffer and zero terminate it. + buf_sample.resize(sample_size+1); + memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size); + buf_sample[sample_size] = '\0'; + + // printf("sample: '%s'\n", buf_sample.data()); + + // tokenize the sample + tok_sample.resize(buf_sample.size() + n_max_tokens_overhead); + int n_tokens = llama_tokenize(lctx, + buf_sample.data(), + tok_sample.data(), + tok_sample.size(), false); + if (n_tokens < 0) { + tok_sample.resize(-n_tokens); + n_tokens = llama_tokenize(lctx, + buf_sample.data(), + tok_sample.data(), + tok_sample.size(), false); + GGML_ASSERT(n_tokens >= 0); + } + GGML_ASSERT(n_tokens <= tok_sample.size()); + + if ((size_t) n_tokens > context_length) { + ++found_too_big_sample; + } else if ((size_t) n_tokens < context_length) { + ++found_too_small_sample; + } + found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens); + found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens); + + // write out tokens, start and size of sample + // overwrite the string start position with the token start position + out_samples_begin[i] = out_tokens.size(); + out_samples_size[i] = (size_t) n_tokens; + out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens); } else { - printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str()); + out_samples_begin[i] = out_tokens.size(); + out_samples_size[i] = 0; } + + } + if (found_too_big_sample > 0) { + printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. 
samples will be cut off.\n", + __func__, found_too_big_sample, found_max_sample_size, context_length); + } + + if (found_too_small_sample > 0) { + printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n", + __func__, found_too_small_sample, found_min_sample_size, context_length); + } + + if (found_empty_sample) { + printf("%s: warning: found %zu empty samples.\n", + __func__, found_empty_sample); + } + } + printf("%s: total number of samples: %zu\n", + __func__, out_samples_begin.size()); - return n_tokens; + GGML_ASSERT(out_samples_begin.size() == out_samples_size.size()); + + return out_tokens.size(); } -static void shuffle_ints(int * begin, int * end) { - if (end <= begin) return; - int max=begin[0]; - for (int i=1; i<(int) (end-begin); ++i) { - if (begin[i] > max) { - max = begin[i]; +static void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) { + std::stringstream s_rng_state; + s_rng_state.imbue(std::locale::classic()); + s_rng_state.exceptions(std::stringstream::failbit); + s_rng_state.str(rng_state); + s_rng_state >> rng; +} + +static std::string mt19937_get_state(const std::mt19937& rng) { + std::stringstream s_rng_state; + s_rng_state.imbue(std::locale::classic()); + s_rng_state << rng; + return s_rng_state.str(); +} + +static std::string mt19937_seed_to_state(unsigned seed) { + std::mt19937 rng(seed); + return mt19937_get_state(rng); +} + +static std::string shuffle_samples( + const std::string & rng_state, + const size_t * begins, + const size_t * sizes, + size_t * shuffled_begins, + size_t * shuffled_sizes, + size_t count) { + if (count == 0) return rng_state; + + std::mt19937 rng; + mt19937_set_state(rng, rng_state); + + // sort indices by random value for each index + std::vector<size_t> idcs; + { + std::vector<unsigned> rnd; + idcs.resize(count); + rnd.resize(count); + for (unsigned i=0; i<count; ++i) { + idcs[i] = i; + rnd[i] = rng(); + } + + std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){ + // stable sort for reproducibility + return rnd[a] < rnd[b] || (rnd[a] == rnd[b] && a < b); + }); + } + + // create random offsets + for (unsigned i=0; i<count; ++i) { + shuffled_begins[i] = begins[idcs[i]]; + shuffled_sizes[i] = sizes[idcs[i]]; + } + + return mt19937_get_state(rng); +} - } - } - std::vector<float> vals; - vals.resize(max+1); - for (int i=0; i<(int) vals.size(); ++i) { - vals[i] = frand(); - } - std::sort(begin, end, [&vals](int a, int b){ - return vals[a] < vals[b]; - }); -} @@ ... @@ static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) { load_llama_model_gguf(fctx, f_ggml_ctx, model); if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) >= 0) { uint32_t file_version = 0xFFFFFFFFu; GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); - GGML_ASSERT(file_version == 0); + GGML_ASSERT(file_version <= 1); std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL; GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE); GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL); - GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); + if (file_version == 0) { + + GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); + + } else if (file_version == 1) { + + GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT); + GGUF_GET_KEY(fctx, model->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT); + + 
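The shuffle above is resumable because std::mt19937 can round-trip its complete internal state through a stream, which is exactly what mt19937_get_state and mt19937_set_state rely on. A small, self-contained demonstration of the idea (helper names mirror the patch, but the demo itself is illustrative):

    #include <cassert>
    #include <locale>
    #include <random>
    #include <sstream>
    #include <string>

    // std::mt19937 defines operator<< / operator>> for its full internal state.
    static std::string get_state(const std::mt19937 & rng) {
        std::stringstream ss;
        ss.imbue(std::locale::classic());  // locale-independent formatting
        ss << rng;
        return ss.str();
    }

    static void set_state(std::mt19937 & rng, const std::string & state) {
        std::stringstream ss;
        ss.imbue(std::locale::classic());
        ss.str(state);
        ss >> rng;
    }

    int main() {
        std::mt19937 rng(42);
        const std::string saved = get_state(rng);  // e.g. persisted in a checkpoint

        const unsigned a = rng();                  // draw once

        std::mt19937 restored;
        set_state(restored, saved);                // resume from the checkpoint
        assert(restored() == a);                   // the same draw is reproduced
    }

Since shuffle_samples derives its permutation purely from draws of this generator (sorting indices by per-index random keys), restoring rng_state_current from a checkpoint regenerates the identical epoch ordering.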
GGUF_GET_KEY(fctx, model->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH); + GGUF_GET_KEY(fctx, model->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE); + GGUF_GET_KEY(fctx, model->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, model->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE); + } load_opt_context_gguf(fctx, f_ggml_ctx, opt); } else { @@ -1196,11 +1527,17 @@ static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) { save_llama_model_gguf(fctx, fn_vocab_model, model); - gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 0); + gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1); gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL); - gguf_set_val_u32(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its); - gguf_set_val_u32(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples); - gguf_set_val_u32(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, model->train_epochs); + + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) model->shuffle_samples_hash); + gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, model->shuffle_rng_state_current.c_str()); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) model->shuffle_sample_count); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) model->shuffle_next_sample); save_opt_context_gguf(fctx, opt); } @@ -1283,12 +1620,21 @@ struct train_params { int print_info_interval; - bool samples_start_after_nl; bool use_adam; bool use_flash; bool use_checkpointing; bool use_alloc; + std::string sample_start; + bool include_sample_start; + bool escape; + bool overlapping_samples; + bool fill_with_next_samples; + bool separate_with_eos; + bool separate_with_bos; + + bool force_reshuffle; + // only adam int warmup; int cos_decay_steps; @@ -1347,12 +1693,20 @@ struct train_params get_default_train_params() { params.print_info_interval = 1; - params.samples_start_after_nl = false; params.use_adam = true; params.use_flash = true; params.use_checkpointing = true; params.use_alloc = true; + params.sample_start = ""; + params.include_sample_start = false; + params.escape = false; + params.overlapping_samples = false; + params.fill_with_next_samples = false; + params.separate_with_eos = false; + params.separate_with_bos = true; + params.force_reshuffle = false; + params.opt_past = 0; params.opt_delta = 1e-5f; params.opt_max_no_improvement = 0; @@ -1408,7 +1762,16 @@ static void train_print_usage(int /*argc*/, char ** argv, const struct train_par fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation); fprintf(stderr, " --print-info-interval N Print infos during 
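The load path above branches on training.file_version so that version-0 checkpoints (32-bit counters, no shuffle state) keep loading after the format change. A distilled sketch of that gate, using a flat key/value map as a hypothetical stand-in for the real gguf accessors (gguf_get_val_u32 / gguf_get_val_u64):

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <stdexcept>
    #include <string>

    using KV = std::map<std::string, uint64_t>;

    struct TrainCounters { uint64_t its, samples, tokens, epochs; };

    static TrainCounters read_counters(const KV & kv, uint32_t file_version) {
        TrainCounters c{};
        if (file_version == 0) {
            // widen the legacy 32-bit counters; the epoch counter did not exist yet
            c.its     = kv.at("training.iteration_count");
            c.samples = kv.at("training.sample_count");
            c.tokens  = kv.at("training.token_count");
            c.epochs  = 0;
        } else if (file_version == 1) {
            c.its     = kv.at("training.iteration_count");
            c.samples = kv.at("training.sample_count");
            c.tokens  = kv.at("training.token_count");
            c.epochs  = kv.at("training.epoch_count");
        } else {
            throw std::runtime_error("unsupported training.file_version");
        }
        return c;
    }

    int main() {
        KV kv = {
            {"training.iteration_count", 256},
            {"training.sample_count",    4096},
            {"training.token_count",     1u << 20},
            {"training.epoch_count",     2},
        };
        TrainCounters c = read_counters(kv, 1);
        printf("its=%llu epochs=%llu\n",
               (unsigned long long) c.its, (unsigned long long) c.epochs);
    }

Note the save path always writes version 1, while the assert on load accepts file_version <= 1; that asymmetry is the usual forward-migration pattern: old files stay readable, new files use the new layout.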
training each N examples (default %d)\n", params->print_info_interval); - fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off"); + fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str()); + fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n"); + fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); + fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n"); + fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n"); + fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : ""); + fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : ""); + fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : ""); + fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : ""); + fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n"); fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n"); fprintf(stderr, " --use-adam Use Adam optimizer (default)\n"); fprintf(stderr, " --no-flash Don't use flash attention \n"); @@ -1586,8 +1949,30 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par break; } params->print_info_interval = std::stoi(argv[i]); - } else if (arg == "--samples-after-nl") { - params->samples_start_after_nl = true; + } else if (arg == "--sample-start") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->sample_start = std::string(argv[i]); + } else if (arg == "--escape") { + params->escape = true; + } else if (arg == "--include-sample-start") { + params->include_sample_start = true; + } else if (arg == "--overlapping-samples") { + params->overlapping_samples = true; + } else if (arg == "--fill-with-next-samples") { + params->fill_with_next_samples = true; + } else if (arg == "--separate-with-eos") { + params->separate_with_eos = true; + } else if (arg == "--separate-with-bos") { + params->separate_with_bos = true; + } else if (arg == "--no-separate-with-eos") { + params->separate_with_eos = false; + } else if (arg == "--no-separate-with-bos") { + params->separate_with_bos = false; + } else if (arg == "--force-reshuffle") { + params->force_reshuffle = true; } else if (arg == "--use-lbfgs") { params->use_adam = false; } else if (arg == "--use-adam") { @@ -1742,6 +2127,9 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par train_print_usage(argc, argv, &default_params); exit(1); } + if (params->escape) { + process_escapes(params->sample_start); + } return true; } @@ -1754,14 +2142,42 @@ struct opt_callback_data { int last_save_iter; 
llama_token * tokens_data; size_t tokens_size; - int * samples_data; - size_t samples_size; - int shuffle_countdown; + size_t * samples_begin; + size_t * samples_size; + size_t * shuffled_samples_begin; + size_t * shuffled_samples_size; + size_t samples_count; struct ggml_tensor * tokens_input; struct ggml_tensor * target_logits; struct ggml_tensor * target_probs; + int first_iter; + int64_t last_time; + double millis_per_iter; }; +static void print_duration(double fmillis) { + if (fmillis < 1000.0f) { + printf("%.1fms", (float) fmillis); + return; + } + const int64_t one_sec = 1000; + const int64_t one_min = one_sec * 60; + const int64_t one_hour = one_min * 60; + const int64_t one_day = one_hour * 24; + + int64_t millis = (int64_t) fmillis; + int64_t days = millis/one_day; + int64_t hours = (millis - days*one_day)/one_hour; + int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min; + int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec; + + // to print int64_t either cast to (long long int) or use macro PRId64 from + if (days > 0) { + printf("%lldd ", (long long int) days); + } + printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds); +} + static void opt_callback(void * vdata, int accum_step, float * sched) { struct opt_callback_data * data = (struct opt_callback_data *) vdata; struct train_params * params = data->params; @@ -1770,6 +2186,27 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { int n_ctx = params->n_ctx; if (accum_step == 0) { + // time measurement + int64_t now = ggml_time_ms(); + if (now > data->last_time && opt->iter > data->first_iter) { + double dt = now - data->last_time; + if (data->millis_per_iter == 0.0) { + data->millis_per_iter = dt; + } else { + const double gain = 0.7; + data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain; + } + } + + double remaining_millis = 0.0; + if (data->millis_per_iter > 0.0) { + const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter; + const int done_iter = opt->iter - data->first_iter; + const int remaining_iter = n_iter - done_iter; + remaining_millis = remaining_iter * data->millis_per_iter; + } + + // file saving const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); if (save_now) { int new_iters = opt->iter - data->last_save_iter; @@ -1789,6 +2226,9 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { data->last_save_iter = opt->iter; } + // exclude file saving from time measurement, by measuring last_time after saving + data->last_time = ggml_time_ms(); + *sched = (opt->iter < params->warmup) ? (float) opt->iter / (float) params->warmup : cosine_decay_restart( @@ -1800,32 +2240,79 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { float min_sched = params->adam_min_alpha / params->adam_alpha; *sched = min_sched + *sched * (1.0f - min_sched); - int impr_plot = std::isnan(opt->loss_after) ? 
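print_duration and the new timing fields feed a simple ETA estimate: the per-iteration wall time is smoothed with an exponential moving average and multiplied by the remaining iterations, as opt_callback does below. A runnable sketch of just that estimator (struct name and sample durations are illustrative; the 0.7 gain matches the patch):

    #include <cstdio>

    // Smooth the per-iteration duration with an exponential moving average
    // and project the time remaining. A high gain weighs recent iterations
    // heavily, so the estimate recovers quickly after a slow outlier (e.g.
    // an iteration that saved a checkpoint, which the patch deliberately
    // excludes from measurement by re-reading the clock after saving).
    struct EtaEstimator {
        double millis_per_iter = 0.0;

        void update(double dt_millis) {
            const double gain = 0.7;
            millis_per_iter = (millis_per_iter == 0.0)
                ? dt_millis
                : millis_per_iter*(1.0 - gain) + dt_millis*gain;
        }

        double eta_millis(int remaining_iters) const {
            return remaining_iters * millis_per_iter;
        }
    };

    int main() {
        EtaEstimator eta;
        const double dts[] = { 1000.0, 1100.0, 900.0, 5000.0, 1000.0 };
        int remaining = 5;
        for (double dt : dts) {
            eta.update(dt);
            --remaining;
            printf("dt=%.0fms eta=%.0fms\n", dt, eta.eta_millis(remaining));
        }
    }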
0 : -std::lround(1 + (opt->loss_before - opt->loss_after) * 10.0f); - printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0); + int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f); + if (impr_plot > 0) impr_plot = 0; + if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0; + printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", + __func__, opt->iter, std::min(1+data->model->shuffle_next_sample, data->model->shuffle_sample_count), data->model->shuffle_sample_count, + *sched, opt->loss_after); - } - if (data->shuffle_countdown < n_batch) { - printf("%s: reshuffle samples\n", __func__); - shuffle_ints(data->samples_data, data->samples_data + data->samples_size); - for (int i = 0; i < (int) data->samples_size; ++i) { - GGML_ASSERT(data->samples_data[i]+params->n_ctx-1 < (int) data->tokens_size); + if (data->millis_per_iter > 0) { + printf(" dt="); + print_duration(data->millis_per_iter); + printf(" eta="); + print_duration(remaining_millis); } - data->shuffle_countdown = data->samples_size; + + float improvement = opt->loss_before - opt->loss_after; + const float plot_scale = 10.0f; + int bar_len = (int)(1 + improvement*plot_scale + 0.5); + printf(" |"); + for (int i=0; i<bar_len; ++i) { + printf("-"); + } + printf(">"); + printf("\n"); } - get_example_targets_batch( + int used_samples = get_example_targets_batch( data->lctx, - data->samples_data, - data->samples_size, + data->shuffled_samples_begin, + data->shuffled_samples_size, + data->samples_count, data->tokens_data, data->tokens_size, - opt->iter*params->n_gradient_accumulation + accum_step, + data->model->shuffle_next_sample, data->tokens_input, - data->target_logits, - data->target_probs); + data->target_probs, + params->separate_with_eos, + params->separate_with_bos, + params->fill_with_next_samples); - data->shuffle_countdown -= n_batch; + data->model->shuffle_next_sample += used_samples; + + if (data->model->shuffle_next_sample >= data->model->shuffle_sample_count) { + ++data->model->train_epochs; + printf("%s: reshuffle samples. 
completed epochs: %llu\n", __func__, (long long unsigned) data->model->train_epochs); + // note: we may have used some samples from the current shuffling more than once + data->model->shuffle_rng_state_current = data->model->shuffle_rng_state_next; + data->model->shuffle_rng_state_next = shuffle_samples( + data->model->shuffle_rng_state_current, + data->samples_begin, + data->samples_size, + data->shuffled_samples_begin, + data->shuffled_samples_size, + data->samples_count); + data->model->shuffle_next_sample = 0; + } + +} + +static size_t hash_combine(size_t h1, size_t h2) { + return h1 ^ (h2 << 1); +} + +static size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) { + std::hash h_string; + std::hash h_ull; + size_t h = h_string(std::string(fn)); + h = hash_combine(h, h_ull((unsigned long long) sample_count)); + for (size_t i=0; i< sample_count; ++i) { + h = hash_combine(h, h_ull((unsigned long long) samples_begin[i])); + h = hash_combine(h, h_ull((unsigned long long) samples_size[i])); + } + return h; } int main(int argc, char ** argv) { @@ -1847,13 +2334,6 @@ int main(int argc, char ** argv) { struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params); struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); - printf("%s: tokenize training data\n", __func__); - std::vector train_tokens; - if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) { - fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data); - } - printf("%s: number of training tokens: %d\n", __func__, (int) train_tokens.size()); - struct my_llama_model model; model.hparams.n_vocab = llama_n_vocab(lctx); model.hparams.n_ctx = params.n_ctx; @@ -1869,24 +2349,6 @@ int main(int argc, char ** argv) { print_params(&model.hparams); - std::vector token_noccurs; - std::vector token_notavail; - token_noccurs.resize(model.hparams.n_vocab, 0); - token_notavail.resize(model.hparams.n_vocab, true); - for (int i = 0; i < (int) train_tokens.size(); ++i) { - ++token_noccurs[train_tokens[i]]; - token_notavail[train_tokens[i]] = false; - } - - std::vector token_freq; - token_freq.resize(model.hparams.n_vocab, 0); - int n_unique_tokens = 0; - for (int i = 0; i < (int) token_noccurs.size(); ++i) { - token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size(); - n_unique_tokens += (token_noccurs[i] > 0) ? 
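compute_samples_hash above fingerprints the data file name together with every sample's begin/size pair, so a resumed checkpoint can detect that the training data changed. The shift-xor combiner is the simple composition often shown for std::hash; it is not collision resistant, and it only needs to catch accidental changes between runs. A self-contained demonstration (file name and offsets are illustrative):

    #include <cstdio>
    #include <functional>
    #include <string>
    #include <vector>

    static size_t hash_combine(size_t h1, size_t h2) {
        return h1 ^ (h2 << 1);
    }

    // Same scheme as compute_samples_hash: fold the data file name and
    // every (begin, size) pair into one fingerprint.
    static size_t samples_hash(const std::string & fn,
                               const std::vector<size_t> & begins,
                               const std::vector<size_t> & sizes) {
        std::hash<std::string> h_str;
        std::hash<size_t>      h_sz;
        size_t h = h_str(fn);
        h = hash_combine(h, h_sz(begins.size()));
        for (size_t i = 0; i < begins.size(); ++i) {
            h = hash_combine(h, h_sz(begins[i]));
            h = hash_combine(h, h_sz(sizes[i]));
        }
        return h;
    }

    int main() {
        std::vector<size_t> begins = {0, 10, 25};
        std::vector<size_t> sizes  = {10, 15, 7};
        size_t h1 = samples_hash("train.txt", begins, sizes);
        sizes[2] = 8;  // any change to a sample changes the fingerprint
        size_t h2 = samples_hash("train.txt", begins, sizes);
        printf("%zx vs %zx -> %s\n", h1, h2, h1 == h2 ? "same" : "changed");
    }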
1 : 0; - } - printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens); - struct ggml_init_params lcparams; lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); lcparams.mem_buffer = NULL; @@ -1965,19 +2427,48 @@ int main(int argc, char ** argv) { alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment); } - GGML_ASSERT(n_tokens < (int) train_tokens.size()); - std::vector train_samples; - train_samples.push_back(0); - for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) { - if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) { - train_samples.push_back(i); - } - } - shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size()); - for (int i = 0; i < (int) train_samples.size(); ++i) { - GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size()); - } + std::vector train_tokens; + std::vector train_samples_begin; + std::vector train_samples_size; + printf("%s: tokenize training data\n", __func__); + tokenize_file(lctx, + params.fn_train_data, + params.sample_start, + params.include_sample_start, + params.overlapping_samples, + n_tokens, + train_tokens, + train_samples_begin, + train_samples_size); + GGML_ASSERT(train_samples_begin.size() == train_samples_size.size()); + printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size()); + + size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size()); + const bool changed_train_data = (shuffle_samples_hash != model.shuffle_samples_hash) || (model.shuffle_sample_count != train_samples_size.size()); + if (changed_train_data) { + printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__); + } + if (params.force_reshuffle) { + printf("%s: forced reshuffling of data. 
restarting with newly shuffled epoch.\n", __func__); + } + if ((model.shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { + model.shuffle_rng_state_current = mt19937_seed_to_state(params.seed); + model.shuffle_sample_count = train_samples_size.size(); + model.shuffle_next_sample = 0; + model.shuffle_samples_hash = shuffle_samples_hash; + } + std::vector train_shuffled_samples_begin; + std::vector train_shuffled_samples_size; + train_shuffled_samples_begin.resize(train_samples_begin.size()); + train_shuffled_samples_size.resize(train_samples_size.size()); + model.shuffle_rng_state_next = shuffle_samples( + model.shuffle_rng_state_current, + train_samples_begin.data(), + train_samples_size.data(), + train_shuffled_samples_begin.data(), + train_shuffled_samples_size.data(), + train_samples_size.size()); printf("%s: begin training\n", __func__); struct opt_callback_data opt_cb_data; @@ -1988,22 +2479,21 @@ int main(int argc, char ** argv) { opt_cb_data.last_save_iter = opt->iter; opt_cb_data.tokens_data = train_tokens.data(); opt_cb_data.tokens_size = train_tokens.size(); - opt_cb_data.samples_data = train_samples.data(); - opt_cb_data.samples_size = train_samples.size(); - opt_cb_data.shuffle_countdown = train_samples.size(); + opt_cb_data.samples_begin = train_samples_begin.data(); + opt_cb_data.samples_size = train_samples_size.data(); + opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data(); + opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data(); + opt_cb_data.samples_count = train_samples_size.size(); opt_cb_data.tokens_input = NULL; opt_cb_data.target_logits = NULL; opt_cb_data.target_probs = NULL; + opt_cb_data.first_iter = opt->iter; + opt_cb_data.last_time = ggml_time_ms(); + opt_cb_data.millis_per_iter = 0.0; int64_t t0 = ggml_time_ms(); for (int ex = 0; ex < params.n_examples; ++ex) { - if (ex*n_batch >= (int) train_samples.size()) { - shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size()); - for (int i = 0; i < (int) train_samples.size(); ++i) { - GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size()); - } - } struct ggml_init_params cparams = { compute_size, // mem_size
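Taken together, the startup logic above implements a small resume protocol: recompute the sample hash from the freshly tokenized data, and only if it still matches the checkpoint (and no reshuffle was forced) reuse the saved RNG state and next-sample cursor. A distilled sketch with simplified types (struct and function names are illustrative, not part of the patch):

    #include <cstddef>
    #include <string>

    struct ShuffleState {
        size_t      samples_hash = 0;
        std::string rng_state_current;
        size_t      sample_count = 0;
        size_t      next_sample  = 0;
    };

    // Resume the saved shuffle only when the fingerprint still matches and
    // no reshuffle was forced; otherwise start a new shuffled epoch from the
    // user-provided seed (seed_state = mt19937_seed_to_state(seed)).
    static void resume_or_restart(ShuffleState & st,
                                  size_t data_hash, size_t data_sample_count,
                                  const std::string & seed_state,
                                  bool force_reshuffle) {
        const bool changed = (data_hash != st.samples_hash)
                          || (st.sample_count != data_sample_count);
        if (st.rng_state_current.empty() || changed || force_reshuffle) {
            st.rng_state_current = seed_state;   // restart from the seed
            st.sample_count      = data_sample_count;
            st.next_sample       = 0;
            st.samples_hash      = data_hash;
        }
        // else: keep rng_state_current and next_sample from the checkpoint;
        // shuffle_samples() then regenerates the identical epoch ordering.
    }

    // usage, after tokenizing and hashing the training data:
    //   resume_or_restart(state, new_hash, new_count, seed_state, params.force_reshuffle);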