fix compile warnings

This commit is contained in:
xaedes 2023-09-16 22:19:46 +02:00
parent dd3e7634f0
commit 83061fbdbe
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
5 changed files with 15 additions and 14 deletions

View file

@ -17,9 +17,15 @@ struct random_uniform_distribution {
std::uniform_real_distribution<float> rd;
};
struct train_state * init_train_state(int seed) {
struct train_state * init_train_state() {
struct train_state * state = (struct train_state *) malloc(sizeof(struct train_state));
memset(state, 0, sizeof(struct train_state));
state->train_its = 0;
state->train_samples = 0;
state->train_tokens = 0;
state->train_epochs = 0;
state->shuffle_samples_hash = 0;
state->shuffle_sample_count = 0;
state->shuffle_next_sample = 0;
state->shuffle_rng_state_current = "";
state->shuffle_rng_state_next = "";
@ -35,10 +41,6 @@ void free_train_state(struct train_state * state) {
free(state);
}
struct ggml_opt_context * get_train_state_opt(struct train_state * state) {
return state->opt;
}
struct random_normal_distribution * init_random_normal_distribution(
int seed, float mean, float std, float min, float max
) {
@ -741,7 +743,7 @@ struct llama_file {
die_fmt("read error: %s", strerror(errno));
}
if (ret != 1) {
die_fmt("unexpectedly reached end of file");
die("unexpectedly reached end of file");
}
}
@ -840,7 +842,7 @@ size_t tokenize_file(
std::vector<int> utf8_nunits;
utf8_units.resize(buf.size());
utf8_nunits.resize(buf.size());
size_t n_utf8_chars = mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
if (sample_start.size() == 0) {
// tokenize all data at once
@ -1070,7 +1072,7 @@ struct train_params_common get_default_train_params_common() {
return params;
}
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params) {
void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
// fprintf(stderr, "usage: %s [options]\n", argv[0]);
// fprintf(stderr, "\n");
// fprintf(stderr, "options:\n");

View file

@ -103,7 +103,7 @@ struct train_opt_callback_data {
double millis_per_iter;
};
struct train_state * init_train_state(int seed);
struct train_state * init_train_state();
void free_train_state(struct train_state * state);
struct train_params_common get_default_train_params_common();

View file

@ -1582,7 +1582,7 @@ int main(int argc, char ** argv) {
struct my_llama_lora lora;
struct train_state * train = init_train_state(params.common.seed);
struct train_state * train = init_train_state();
struct ggml_opt_context * opt = train->opt;
load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);

View file

@ -972,7 +972,7 @@ int main(int argc, char ** argv) {
int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch;
struct train_state * train = init_train_state(params.common.seed);
struct train_state * train = init_train_state();
struct ggml_opt_context * opt = train->opt;
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);

1
ggml.c
View file

@ -15134,7 +15134,6 @@ static void ggml_compute_forward_flash_attn_back_f32(
const int64_t elem_q = ggml_nelements(q);
const int64_t elem_k = ggml_nelements(k);
const int64_t elem_v = ggml_nelements(v);
enum ggml_type result_type = dst->type;
GGML_ASSERT(ggml_blck_size(result_type) == 1);