allow empty --prompt-cache file

This allows the use of std::tmpnam(), std::tmpfile(), Python's tempfile.NamedTemporaryFile(), and similar create-empty-file API's for the user.

I switched from the C fopen API to the C++ filesystem api to get around the fact that, to the best of my knowledge, C has no portable way to get the file size above LONG_MAX, with std::ftell() returning long? fallback to std::ifstream for c++  < 17
(the project is currently targeting C++11 it seems - file_exists() and file_size() can be removed when we upgrade to c++17)
This commit is contained in:
divinity76 2024-01-28 20:13:23 +01:00 committed by GitHub
parent 35dec26cc2
commit 6c348978c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -10,6 +10,9 @@
#include <cstring>
#include <ctime>
#include <fstream>
#if __cplusplus >= 201703L
#include <filesystem>
#endif
#include <iostream>
#include <sstream>
#include <string>
@ -39,6 +42,25 @@ static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool file_exists(const std::string &path) {
#if __cplusplus >= 201703L
return std::filesystem::exists(path);
#else
std::ifstream f(path.c_str());
return f.good();
#endif
}
static uint64_t file_size(const std::string &path) {
#if __cplusplus >= 201703L
return std::filesystem::file_size(path);
#else
std::ifstream f;
f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
return static_cast<uint64_t>(f.tellg());
#endif
}
static void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model,
@ -215,24 +237,27 @@ int main(int argc, char ** argv) {
if (!path_session.empty()) {
LOG_TEE("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
// fopen to check for existing session
FILE * fp = std::fopen(path_session.c_str(), "rb");
if (fp != NULL) {
std::fclose(fp);
if (!file_exists(path_session))
{
LOG_TEE("%s: session file does not exist, will create.\n", __func__);
}
else if (file_size(path_session) == 0)
{
LOG_TEE("%s: The session file is empty. A new session will be initialized.\n", __func__);
}
else
{
// The file exists and is not empty
session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out))
{
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
session_tokens.resize(n_token_count_out);
llama_set_rng_seed(ctx, params.seed);
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
} else {
LOG_TEE("%s: session file does not exist, will create\n", __func__);
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
}
}