diff --git a/common/sampling.cpp b/common/sampling.cpp index b4063fe31..123c6b2a7 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -110,12 +110,13 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { } struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) { - gpt_sampler * result = new gpt_sampler(); - - result->grmr = llama_constraint_cp(gsmpl->grmr); - result->smpl = llama_sampler_cp(gsmpl->smpl); - - return result; + return new gpt_sampler { + /* .params = */ gsmpl->params, + /* .bias = */ llama_constraint_cp(gsmpl->bias), + /* .pnlt = */ llama_constraint_cp(gsmpl->pnlt), + /* .grmr = */ llama_constraint_cp(gsmpl->grmr), + /* .smpl = */ llama_sampler_cp(gsmpl->smpl) + }; } void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) { @@ -145,7 +146,7 @@ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { } void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { - llama_print_timings(ctx, gsmpl->smpl); + llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr); } static llama_token gpt_sampler_sample(