use new/delete for train_state instead of malloc/free
using malloc may result in seg faults when trying to assign string fields
This commit is contained in:
parent
8721785c52
commit
ddf5ac257a
1 changed files with 6 additions and 6 deletions
|
@ -18,7 +18,7 @@ struct random_uniform_distribution {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct train_state * init_train_state() {
|
struct train_state * init_train_state() {
|
||||||
struct train_state * state = (struct train_state *) malloc(sizeof(struct train_state));
|
struct train_state * state = new struct train_state;
|
||||||
state->train_its = 0;
|
state->train_its = 0;
|
||||||
state->train_samples = 0;
|
state->train_samples = 0;
|
||||||
state->train_tokens = 0;
|
state->train_tokens = 0;
|
||||||
|
@ -29,16 +29,16 @@ struct train_state * init_train_state() {
|
||||||
state->shuffle_rng_state_current = "";
|
state->shuffle_rng_state_current = "";
|
||||||
state->shuffle_rng_state_next = "";
|
state->shuffle_rng_state_next = "";
|
||||||
|
|
||||||
state->opt = (struct ggml_opt_context *) malloc(sizeof(struct ggml_opt_context));
|
state->opt = new struct ggml_opt_context;
|
||||||
memset(state->opt, 0, sizeof(struct ggml_opt_context));
|
state->opt->ctx = NULL;
|
||||||
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
|
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
|
||||||
|
|
||||||
return state;
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
void free_train_state(struct train_state * state) {
|
void free_train_state(struct train_state * state) {
|
||||||
free(state->opt);
|
delete state->opt;
|
||||||
free(state);
|
delete state;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct random_normal_distribution * init_random_normal_distribution(
|
struct random_normal_distribution * init_random_normal_distribution(
|
||||||
|
@ -932,7 +932,7 @@ size_t tokenize_file(
|
||||||
: (i+1 < out_samples_begin.size()
|
: (i+1 < out_samples_begin.size()
|
||||||
? out_samples_begin[i+1]
|
? out_samples_begin[i+1]
|
||||||
: data_str.size());
|
: data_str.size());
|
||||||
if (utf8_units[sample_end] > 0) {
|
if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
|
||||||
// sample end is in the middle of an utf8 character.
|
// sample end is in the middle of an utf8 character.
|
||||||
// advance sample_end to the begin of the next utf8 character.
|
// advance sample_end to the begin of the next utf8 character.
|
||||||
sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
|
sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue