llama : serialize rng into minimum amount of space required

This commit is contained in:
David Friehs 2024-01-08 08:58:06 +01:00
parent e872af8dee
commit 69d44e3e3f

View file

@ -10228,14 +10228,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
std::ostringstream rng_ss;
rng_ss << ctx->rng;
const size_t rng_size = rng_ss.str().size();
char rng_buf[LLAMA_MAX_RNG_STATE];
const std::string & rng_str = rng_ss.str();
const size_t rng_size = rng_str.size();
memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
data_ctx->write(&rng_size, sizeof(rng_size));
data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
data_ctx->write(&rng_size, sizeof(rng_size));
data_ctx->write(rng_str.data(), rng_size);
}
// copy logits
@ -10356,13 +10355,13 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
// set rng
{
size_t rng_size;
char rng_buf[LLAMA_MAX_RNG_STATE];
memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
std::istringstream rng_ss;
rng_ss.str(std::string(&rng_buf[0], rng_size));
std::string rng_str((char *)inp, rng_size); inp += rng_size;
std::istringstream rng_ss(rng_str);
rng_ss >> ctx->rng;
GGML_ASSERT(!rng_ss.fail());