llama : restore the original load/save session implementation
Will migrate this to GGUF in the future
This commit is contained in:
parent
5b94b14d5d
commit
6412e97427
2 changed files with 79 additions and 14 deletions
86
llama.cpp
86
llama.cpp
|
@ -284,6 +284,12 @@ struct llama_file {
|
|||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() {
|
||||
uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
void write_raw(const void * ptr, size_t len) const {
|
||||
if (len == 0) {
|
||||
return;
|
||||
|
@ -295,6 +301,10 @@ struct llama_file {
|
|||
}
|
||||
}
|
||||
|
||||
void write_u32(std::uint32_t val) const {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
~llama_file() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
|
@ -4308,13 +4318,13 @@ struct llama_data_buffer_context : llama_data_context {
|
|||
};
|
||||
|
||||
struct llama_data_file_context : llama_data_context {
|
||||
FILE * file;
|
||||
llama_file * file;
|
||||
size_t size_written = 0;
|
||||
|
||||
llama_data_file_context(FILE * f) : file(f) {}
|
||||
llama_data_file_context(llama_file * f) : file(f) {}
|
||||
|
||||
void write(const void * src, size_t size) override {
|
||||
fwrite(src, size, 1, file);
|
||||
file->write_raw(src, size);
|
||||
size_written += size;
|
||||
}
|
||||
|
||||
|
@ -4549,13 +4559,55 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
|||
|
||||
static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||
llama_file file(path_session, "rb");
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(path_session);
|
||||
GGML_UNUSED(tokens_out);
|
||||
GGML_UNUSED(n_token_capacity);
|
||||
GGML_UNUSED(n_token_count_out);
|
||||
|
||||
// TODO: implement with GGUF format
|
||||
// sanity checks
|
||||
{
|
||||
const uint32_t magic = file.read_u32();
|
||||
const uint32_t version = file.read_u32();
|
||||
|
||||
if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
|
||||
LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_hparams session_hparams;
|
||||
file.read_raw(&session_hparams, sizeof(llama_hparams));
|
||||
|
||||
if (session_hparams != ctx->model.hparams) {
|
||||
LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// load the prompt
|
||||
{
|
||||
const uint32_t n_token_count = file.read_u32();
|
||||
|
||||
if (n_token_count > n_token_capacity) {
|
||||
LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
|
||||
return false;
|
||||
}
|
||||
|
||||
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
|
||||
*n_token_count_out = n_token_count;
|
||||
}
|
||||
|
||||
// restore the context state
|
||||
{
|
||||
const size_t n_state_size_cur = file.size - file.tell();
|
||||
const size_t n_state_size_max = llama_get_state_size(ctx);
|
||||
|
||||
if (n_state_size_cur > n_state_size_max) {
|
||||
LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> state_data(n_state_size_max);
|
||||
file.read_raw(state_data.data(), n_state_size_cur);
|
||||
|
||||
llama_set_state_data(ctx, state_data.data());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -4570,11 +4622,19 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
|
|||
|
||||
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
||||
llama_file file(path_session, "wb");
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(tokens);
|
||||
GGML_UNUSED(n_token_count);
|
||||
|
||||
// TODO: implement with GGUF format
|
||||
file.write_u32(LLAMA_SESSION_MAGIC);
|
||||
file.write_u32(LLAMA_SESSION_VERSION);
|
||||
|
||||
file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
|
||||
|
||||
// save the prompt
|
||||
file.write_u32((uint32_t) n_token_count);
|
||||
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
||||
|
||||
// save the context state using stream saving
|
||||
llama_data_file_context data_ctx(&file);
|
||||
llama_copy_state_data_internal(ctx, &data_ctx);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
7
llama.h
7
llama.h
|
@ -34,7 +34,12 @@
|
|||
# define DEPRECATED(func, hint) func
|
||||
#endif
|
||||
|
||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||
|
||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 1
|
||||
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
||||
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue