llama : restore the original load/save session implementation

Will migrate this to GGUF in the future
Georgi Gerganov 2023-08-16 17:35:37 +03:00
parent 5b94b14d5d
commit 6412e97427
2 changed files with 79 additions and 14 deletions
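
For orientation, the session-file layout that the restored code reads and writes, reconstructed from the diff below (fields are written back-to-back via llama_file::write_raw, so there is no padding):

    // session file layout:
    //   uint32_t  magic           LLAMA_SESSION_MAGIC ('ggsn')
    //   uint32_t  version         LLAMA_SESSION_VERSION (1)
    //   bytes     hparams         raw dump, sizeof(llama_hparams)
    //   uint32_t  n_token_count   number of prompt tokens
    //   tokens    prompt          n_token_count * sizeof(llama_token)
    //   bytes     state           llama_copy_state_data_internal() stream (rest of file)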

llama.cpp

@@ -284,6 +284,12 @@ struct llama_file {
         }
     }
 
+    uint32_t read_u32() {
+        uint32_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
+
     void write_raw(const void * ptr, size_t len) const {
         if (len == 0) {
             return;
@@ -295,6 +301,10 @@ struct llama_file {
         }
     }
 
+    void write_u32(std::uint32_t val) const {
+        write_raw(&val, sizeof(val));
+    }
+
     ~llama_file() {
         if (fp) {
             std::fclose(fp);
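
The two helpers are symmetric wrappers over the existing read_raw/write_raw, so fixed-width header fields round-trip exactly. A minimal sketch (llama_file is internal to llama.cpp, not public API, and the file name here is made up):

    {
        llama_file out("header.bin", "wb");
        out.write_u32(0x6767736eu);           // forwards to write_raw(&val, sizeof(val))
    }                                         // destructor closes the file

    llama_file in("header.bin", "rb");
    const uint32_t magic = in.read_u32();     // forwards to read_raw; magic == 0x6767736eu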
@@ -4308,13 +4318,13 @@ struct llama_data_buffer_context : llama_data_context {
 };
 
 struct llama_data_file_context : llama_data_context {
-    FILE * file;
+    llama_file * file;
     size_t size_written = 0;
 
-    llama_data_file_context(FILE * f) : file(f) {}
+    llama_data_file_context(llama_file * f) : file(f) {}
 
     void write(const void * src, size_t size) override {
-        fwrite(src, size, 1, file);
+        file->write_raw(src, size);
         size_written += size;
     }
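
Switching the member from FILE * to llama_file * routes the state stream through the same write_raw path used for the header fields. For context, a hedged sketch of the llama_data_context interface that both llama_data_buffer_context and llama_data_file_context implement (the base class sits above this hunk and is not shown in the diff; the exact members are assumed):

    struct llama_data_context {
        virtual void write(const void * src, size_t size) = 0;
        virtual size_t get_size_written() = 0;
        virtual ~llama_data_context() = default;
    };

Because llama_copy_state_data_internal() serializes through this interface, the same code path can fill a caller-provided buffer or stream straight to disk without allocating a full-size temporary.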
@@ -4549,13 +4559,55 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");
 
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(path_session);
-    GGML_UNUSED(tokens_out);
-    GGML_UNUSED(n_token_capacity);
-    GGML_UNUSED(n_token_count_out);
-
-    // TODO: implement with GGUF format
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
+            LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
+
+        llama_hparams session_hparams;
+        file.read_raw(&session_hparams, sizeof(llama_hparams));
+
+        if (session_hparams != ctx->model.hparams) {
+            LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__);
+            return false;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_max = llama_get_state_size(ctx);
+
+        if (n_state_size_cur > n_state_size_max) {
+            LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
+            return false;
+        }
+
+        std::vector<uint8_t> state_data(n_state_size_max);
+        file.read_raw(state_data.data(), n_state_size_cur);
+
+        llama_set_state_data(ctx, state_data.data());
+    }
 
     return true;
 }
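
A usage sketch for the restored loader, assuming ctx is an initialized llama_context; the path and the 1024-token capacity are made up:

    std::vector<llama_token> tokens(1024);   // capacity for the saved prompt
    size_t n_token_count = 0;

    if (llama_load_session_file(ctx, "session.bin", tokens.data(), tokens.size(), &n_token_count)) {
        tokens.resize(n_token_count);
        // the KV cache and RNG state were restored via llama_set_state_data(); generation can resume
    }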
@@ -4570,11 +4622,19 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
 bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
     llama_file file(path_session, "wb");
 
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(tokens);
-    GGML_UNUSED(n_token_count);
-
-    // TODO: implement with GGUF format
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
+
+    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_data_file_context data_ctx(&file);
+    llama_copy_state_data_internal(ctx, &data_ctx);
 
     return true;
 }
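
And the matching save call, assuming tokens holds the prompt that was evaluated in ctx:

    if (!llama_save_session_file(ctx, "session.bin", tokens.data(), tokens.size())) {
        fprintf(stderr, "failed to save session file\n");
    }

Note that the save path streams the context state through llama_data_file_context rather than serializing it into a full-size memory buffer first.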

llama.h

@@ -36,6 +36,11 @@
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
+#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
+
+#define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
+#define LLAMA_SESSION_VERSION 1
+
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
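
As a quick check on the new constant: the magic is the ASCII bytes 'g' 'g' 's' 'n' packed into a 32-bit value, which a static_assert can verify at compile time:

    static_assert(0x6767736eu == (('g' << 24) | ('g' << 16) | ('s' << 8) | 'n'), "'ggsn' magic");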