Marko Tasic 2024-08-27 10:30:31 +01:00 committed by GitHub
commit 0f398dddb3
2 changed files with 135 additions and 7 deletions


@@ -12,6 +12,7 @@ This example program allows you to use various LLaMA language models easily and
6. [Generation Flags](#generation-flags)
7. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options)
8. [Additional Options](#additional-options)
9. [Shared Library](#shared-library)
## Quick Start
@@ -317,3 +318,82 @@ These options provide extra functionality and customization when running the LLa
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
- `-hfr URL --hf-repo URL`: The URL of the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.
## Shared Library
To build `llama-cli` as a shared library, run the following command from the root directory of the repository:
```bash
CXXFLAGS="-DSHARED_LIB" LDFLAGS="-shared -o libllama-cli.so" make llama-cli
```
The resulting library exports the function `llama_cli_main`, which can be invoked via FFI and accepts the same options as `llama-cli`:
```c
int llama_cli_main(int argc, char ** argv);
```
To let you supply custom file streams for STDOUT and STDERR, and to intercept token printing, we provide four setter functions:
```c
void llama_set_stdout(FILE* f);
void llama_set_stderr(FILE* f);
void llama_set_fprintf(int (*func)(FILE*, const char*, ...));
void llama_set_fflush(int (*func)(FILE*));
```
This is particularly beneficial if you need to use `libllama-cli.so` through FFI in other programming languages without altering the default STDOUT and STDERR file descriptors.
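For example, here is a minimal sketch that redirects both streams to a log file before invoking `llama_cli_main`. It assumes a POSIX system, where `fopen` from the process's C library can be resolved via `CDLL(None)`; the file name `llama-run.log` is just a placeholder:
```python
from ctypes import CDLL, c_char_p, c_void_p

# Resolve fopen from the C runtime to obtain a FILE* that the library can use
libc = CDLL(None)                           # POSIX: load symbols of the running process
libc.fopen.argtypes = [c_char_p, c_char_p]
libc.fopen.restype = c_void_p               # treat FILE* as an opaque pointer

lib = CDLL('./libllama-cli.so')
lib.llama_set_stdout.argtypes = [c_void_p]
lib.llama_set_stderr.argtypes = [c_void_p]

log = libc.fopen(b'llama-run.log', b'w')    # placeholder file name
lib.llama_set_stdout(log)
lib.llama_set_stderr(log)
# ... then call lib.llama_cli_main(argc, argv) as in the example below
```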
Here's a Python example that handles token printing on its own, without relying on STDOUT and STDERR:
```python
from ctypes import *
#
# open shared library
#
lib = CDLL('./libllama-cli.so')
lib.llama_cli_main.argtypes = [c_int, POINTER(c_char_p)]
lib.llama_cli_main.restype = c_int
#
# redefine fprintf and fflush
#
@CFUNCTYPE(c_int, c_void_p, c_char_p, c_char_p)
def fprintf(file_obj, fmt, *args):
    content = fmt.decode('utf-8') % tuple(arg.decode('utf-8') for arg in args)
    print(content, flush=True, end='')
    size = len(content)
    return size

@CFUNCTYPE(c_int, c_void_p)
def fflush(file_obj):
    print(flush=True, end='')
    return 0
lib.llama_set_fprintf(fprintf)
lib.llama_set_fflush(fflush)
#
# generate and print token by token
#
argv: list[bytes] = [
    b'llama-cli',
    b'-m',
    b'models/7B/ggml-model.bin',
    b'--no-display-prompt',
    b'--simple-io',
    b'--log-disable',
    b'-p',
    b'What is cosmos?',
]
argc = len(argv)
argv = (c_char_p * argc)(*argv)
res = lib.llama_cli_main(argc, argv)
assert res == 0
```
If necessary, you can also capture the generated tokens inside your Python `fprintf` implementation without printing them at all, as in the sketch below.
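Here is a minimal, self-contained sketch of that approach. The formatting logic assumes the `"%s"` format string used for token output (other messages are appended verbatim), and the model path and prompt are the same placeholders as above:
```python
from ctypes import CDLL, CFUNCTYPE, POINTER, c_int, c_char_p, c_void_p

lib = CDLL('./libllama-cli.so')
lib.llama_cli_main.argtypes = [c_int, POINTER(c_char_p)]
lib.llama_cli_main.restype = c_int

captured: list[str] = []

@CFUNCTYPE(c_int, c_void_p, c_char_p, c_char_p)
def capture_fprintf(file_obj, fmt, arg):
    # Token output arrives as fprintf(f, "%s", token); collect it instead of printing
    text = fmt.decode('utf-8')
    piece = text % (arg.decode('utf-8'),) if arg is not None and '%s' in text else text
    captured.append(piece)
    return len(piece)

lib.llama_set_fprintf(capture_fprintf)

argv = [b'llama-cli', b'-m', b'models/7B/ggml-model.bin',
        b'--no-display-prompt', b'--simple-io', b'--log-disable',
        b'-p', b'What is cosmos?']
argc = len(argv)
assert lib.llama_cli_main(argc, (c_char_p * argc)(*argv)) == 0

print(''.join(captured))  # the full generation, assembled after the run
```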


@@ -39,6 +39,39 @@ static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool need_insert_eot = false;
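// Output hooks: default to the standard stdio streams and functions,
// and can be replaced at runtime through the llama_set_* setters below.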
static FILE *llama_stdout = stdout;
static FILE *llama_stderr = stderr;
static int (*llama_fprintf)(FILE*, const char*, ...) = fprintf;
static int (*llama_fflush)(FILE*) = fflush;
#ifdef __cplusplus
extern "C" {
#endif
void llama_set_stdout(FILE* f);
void llama_set_stderr(FILE* f);
void llama_set_fprintf(int (*func)(FILE*, const char*, ...));
void llama_set_fflush(int (*func)(FILE*));
void llama_set_stdout(FILE* f) {
llama_stdout = f;
}
void llama_set_stderr(FILE* f) {
llama_stderr = f;
}
void llama_set_fprintf(int (*func)(FILE*, const char*, ...)) {
llama_fprintf = func;
}
void llama_set_fflush(int (*func)(FILE*)) {
llama_fflush = func;
}
#ifdef __cplusplus
}
#endif
static bool file_exists(const std::string & path) {
std::ifstream f(path.c_str());
@@ -65,7 +98,7 @@ static void write_logfile(
const bool success = fs_create_directory_with_parents(params.logdir);
if (!success) {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
llama_fprintf(llama_stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
return;
}
@@ -74,7 +107,7 @@ static void write_logfile(
FILE * logfile = fopen(logfile_path.c_str(), "w");
if (logfile == NULL) {
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
llama_fprintf(llama_stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
return;
}
@@ -128,7 +161,18 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector<l
return formatted;
}
int main(int argc, char ** argv) {
#ifdef __cplusplus
extern "C" {
#endif
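// When built with -DSHARED_LIB, export the entry point as llama_cli_main instead of main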
#ifdef SHARED_LIB
int llama_cli_main(int argc, char ** argv);
int llama_cli_main(int argc, char ** argv)
#else
int main(int argc, char ** argv)
#endif
{
gpt_params params;
g_params = &params;
@@ -533,7 +577,7 @@ int main(int argc, char ** argv) {
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
if (!ctx_sampling) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
llama_fprintf(llama_stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}
@@ -570,7 +614,7 @@ int main(int argc, char ** argv) {
console::set_display(console::error);
printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
console::set_display(console::reset);
fflush(stdout);
llama_fflush(llama_stdout);
}
if (ga_n == 1) {
@@ -770,7 +814,7 @@ int main(int argc, char ** argv) {
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
// Console/Stream Output
fprintf(stdout, "%s", token_str.c_str());
llama_fprintf(llama_stdout, "%s", token_str.c_str());
// Record Displayed Tokens To Log
// Note: Generated tokens are created one by one hence this check
@@ -783,7 +827,7 @@ int main(int argc, char ** argv) {
output_ss << token_str;
}
fflush(stdout);
llama_fflush(llama_stdout);
}
}
@@ -995,3 +1039,7 @@ int main(int argc, char ** argv) {
return 0;
}
#ifdef __cplusplus
}
#endif