diff --git a/common/train.cpp b/common/train.cpp index 7ffaf94a8..e54f9b5fe 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -17,11 +17,17 @@ struct random_uniform_distribution { std::uniform_real_distribution rd; }; -struct train_state * init_train_state(int seed) { +struct train_state * init_train_state() { struct train_state * state = (struct train_state *) malloc(sizeof(struct train_state)); - memset(state, 0, sizeof(struct train_state)); + state->train_its = 0; + state->train_samples = 0; + state->train_tokens = 0; + state->train_epochs = 0; + state->shuffle_samples_hash = 0; + state->shuffle_sample_count = 0; + state->shuffle_next_sample = 0; state->shuffle_rng_state_current = ""; - state->shuffle_rng_state_next = ""; + state->shuffle_rng_state_next = ""; state->opt = (struct ggml_opt_context *) malloc(sizeof(struct ggml_opt_context)); memset(state->opt, 0, sizeof(struct ggml_opt_context)); @@ -35,10 +41,6 @@ void free_train_state(struct train_state * state) { free(state); } -struct ggml_opt_context * get_train_state_opt(struct train_state * state) { - return state->opt; -} - struct random_normal_distribution * init_random_normal_distribution( int seed, float mean, float std, float min, float max ) { @@ -741,7 +743,7 @@ struct llama_file { die_fmt("read error: %s", strerror(errno)); } if (ret != 1) { - die_fmt("unexpectedly reached end of file"); + die("unexpectedly reached end of file"); } } @@ -840,7 +842,7 @@ size_t tokenize_file( std::vector utf8_nunits; utf8_units.resize(buf.size()); utf8_nunits.resize(buf.size()); - size_t n_utf8_chars = mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size()); + mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size()); if (sample_start.size() == 0) { // tokenize all data at once @@ -1070,7 +1072,7 @@ struct train_params_common get_default_train_params_common() { return params; } -void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params) { +void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) { // fprintf(stderr, "usage: %s [options]\n", argv[0]); // fprintf(stderr, "\n"); // fprintf(stderr, "options:\n"); diff --git a/common/train.h b/common/train.h index db63a5d16..97f08964d 100644 --- a/common/train.h +++ b/common/train.h @@ -103,7 +103,7 @@ struct train_opt_callback_data { double millis_per_iter; }; -struct train_state * init_train_state(int seed); +struct train_state * init_train_state(); void free_train_state(struct train_state * state); struct train_params_common get_default_train_params_common(); diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index ae3582a54..50eda730d 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -1582,7 +1582,7 @@ int main(int argc, char ** argv) { struct my_llama_lora lora; - struct train_state * train = init_train_state(params.common.seed); + struct train_state * train = init_train_state(); struct ggml_opt_context * opt = train->opt; load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 5c37776f3..861d08294 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -972,7 +972,7 @@ int main(int argc, char ** argv) { int n_vocab = model.hparams.n_vocab; int n_batch = params.common.n_batch; - struct train_state * train = init_train_state(params.common.seed); + struct train_state * train = init_train_state(); struct ggml_opt_context * opt = train->opt; struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); diff --git a/ggml.c b/ggml.c index e00324f8a..ec9ea80a4 100644 --- a/ggml.c +++ b/ggml.c @@ -15134,7 +15134,6 @@ static void ggml_compute_forward_flash_attn_back_f32( const int64_t elem_q = ggml_nelements(q); const int64_t elem_k = ggml_nelements(k); - const int64_t elem_v = ggml_nelements(v); enum ggml_type result_type = dst->type; GGML_ASSERT(ggml_blck_size(result_type) == 1);