diff --git a/llama.cpp b/llama.cpp index 592e7e48a..38a2d5ba8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -284,6 +284,12 @@ struct llama_file { } } + uint32_t read_u32() { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + void write_raw(const void * ptr, size_t len) const { if (len == 0) { return; @@ -295,6 +301,10 @@ struct llama_file { } } + void write_u32(std::uint32_t val) const { + write_raw(&val, sizeof(val)); + } + ~llama_file() { if (fp) { std::fclose(fp); @@ -4308,13 +4318,13 @@ struct llama_data_buffer_context : llama_data_context { }; struct llama_data_file_context : llama_data_context { - FILE * file; + llama_file * file; size_t size_written = 0; - llama_data_file_context(FILE * f) : file(f) {} + llama_data_file_context(llama_file * f) : file(f) {} void write(const void * src, size_t size) override { - fwrite(src, size, 1, file); + file->write_raw(src, size); size_written += size; } @@ -4549,13 +4559,55 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(path_session, "rb"); - GGML_UNUSED(ctx); - GGML_UNUSED(path_session); - GGML_UNUSED(tokens_out); - GGML_UNUSED(n_token_capacity); - GGML_UNUSED(n_token_count_out); - // TODO: implement with GGUF format + // sanity checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { + LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); + return false; + } + + llama_hparams session_hparams; + file.read_raw(&session_hparams, sizeof(llama_hparams)); + + if (session_hparams != ctx->model.hparams) { + LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); + return false; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return false; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t n_state_size_cur = file.size - file.tell(); + const size_t n_state_size_max = llama_get_state_size(ctx); + + if (n_state_size_cur > n_state_size_max) { + LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); + return false; + } + + std::vector state_data(n_state_size_max); + file.read_raw(state_data.data(), n_state_size_cur); + + llama_set_state_data(ctx, state_data.data()); + } + return true; } @@ -4570,11 +4622,19 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { llama_file file(path_session, "wb"); - GGML_UNUSED(ctx); - GGML_UNUSED(tokens); - GGML_UNUSED(n_token_count); - // TODO: implement with GGUF format + file.write_u32(LLAMA_SESSION_MAGIC); + file.write_u32(LLAMA_SESSION_VERSION); + + file.write_raw(&ctx->model.hparams, sizeof(llama_hparams)); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_file_context data_ctx(&file); + llama_copy_state_data_internal(ctx, &data_ctx); return true; } diff --git a/llama.h b/llama.h index 7bae54d6a..2e407a1db 100644 --- a/llama.h +++ b/llama.h @@ -34,7 +34,12 @@ # define DEPRECATED(func, hint) func #endif -#define LLAMA_DEFAULT_SEED 0xFFFFFFFF +#define LLAMA_DEFAULT_SEED 0xFFFFFFFF + +#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' + +#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN +#define LLAMA_SESSION_VERSION 1 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) // Defined when llama.cpp is compiled with support for offloading model layers to GPU.