use new/delete for train_state instead of malloc/free

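Using malloc may result in segfaults when assigning to the struct's string fields (e.g. shuffle_rng_state_current), because malloc does not run their constructors; new/delete construct and destroy them properly.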
This commit is contained in:
xaedes 2023-09-17 12:48:17 +02:00
parent 8721785c52
commit ddf5ac257a
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@@ -18,7 +18,7 @@ struct random_uniform_distribution {
}; };
struct train_state * init_train_state() { struct train_state * init_train_state() {
struct train_state * state = (struct train_state *) malloc(sizeof(struct train_state)); struct train_state * state = new struct train_state;
state->train_its = 0; state->train_its = 0;
state->train_samples = 0; state->train_samples = 0;
state->train_tokens = 0; state->train_tokens = 0;
@@ -29,16 +29,16 @@ struct train_state * init_train_state() {
state->shuffle_rng_state_current = ""; state->shuffle_rng_state_current = "";
state->shuffle_rng_state_next = ""; state->shuffle_rng_state_next = "";
state->opt = (struct ggml_opt_context *) malloc(sizeof(struct ggml_opt_context)); state->opt = new struct ggml_opt_context;
memset(state->opt, 0, sizeof(struct ggml_opt_context)); state->opt->ctx = NULL;
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM); state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
return state; return state;
} }
void free_train_state(struct train_state * state) { void free_train_state(struct train_state * state) {
free(state->opt); delete state->opt;
free(state); delete state;
} }
struct random_normal_distribution * init_random_normal_distribution( struct random_normal_distribution * init_random_normal_distribution(
@@ -932,7 +932,7 @@ size_t tokenize_file(
: (i+1 < out_samples_begin.size() : (i+1 < out_samples_begin.size()
? out_samples_begin[i+1] ? out_samples_begin[i+1]
: data_str.size()); : data_str.size());
if (utf8_units[sample_end] > 0) { if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
// sample end is in the middle of an utf8 character. // sample end is in the middle of an utf8 character.
// advance sample_end to the begin of the next utf8 character. // advance sample_end to the begin of the next utf8 character.
sample_end += utf8_nunits[sample_end] - utf8_units[sample_end]; sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];