Opt class for positional argument handling (#10508)
Added support for the positional arguments `model` and `prompt`. Added the ability to download models via strings such as:
  llama-run llama3
  llama-run ollama://granite-code
  llama-run ollama://granite-code:8b
  llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
  llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
  llama-run https://example.com/some-file1.gguf
  llama-run some-file2.gguf
  llama-run file://some-file3.gguf
Signed-off-by: Eric Curtin <ecurtin@redhat.com>
This commit is contained in:
		
							parent
							
								
									11e07fd63b
								
							
						
					
					
						commit
						c27ac678dd
					
				
					 7 changed files with 542 additions and 163 deletions
				
			
		
							
								
								
									
										14
									
								
								README.md
									
										
									
									
									
								
							
							
						
						
									
										14
									
								
								README.md
									
										
									
									
									
								
							|  | @ -433,6 +433,20 @@ To learn more about model quantization, [read this documentation](examples/quant | ||||||
| 
 | 
 | ||||||
|     </details> |     </details> | ||||||
| 
 | 
 | ||||||
|  | ## [`llama-run`](examples/run) | ||||||
|  | 
 | ||||||
|  | #### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3]. | ||||||
|  | 
 | ||||||
|  | - <details> | ||||||
|  |     <summary>Run a model with a specific prompt (by default it's pulled from the Ollama registry)</summary> | ||||||
|  | 
 | ||||||
|  |     ```bash | ||||||
|  |     llama-run granite-code | ||||||
|  |     ``` | ||||||
|  | 
 | ||||||
|  |     </details> | ||||||
|  | 
 | ||||||
|  | [^3]: [RamaLama](https://github.com/containers/ramalama) | ||||||
| 
 | 
 | ||||||
| ## [`llama-simple`](examples/simple) | ## [`llama-simple`](examples/simple) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -81,7 +81,7 @@ set(LLAMA_COMMON_EXTRA_LIBS build_info) | ||||||
| # Use curl to download model url | # Use curl to download model url | ||||||
| if (LLAMA_CURL) | if (LLAMA_CURL) | ||||||
|     find_package(CURL REQUIRED) |     find_package(CURL REQUIRED) | ||||||
|     add_definitions(-DLLAMA_USE_CURL) |     target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) | ||||||
|     include_directories(${CURL_INCLUDE_DIRS}) |     include_directories(${CURL_INCLUDE_DIRS}) | ||||||
|     find_library(CURL_LIBRARY curl REQUIRED) |     find_library(CURL_LIBRARY curl REQUIRED) | ||||||
|     set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) |     set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) | ||||||
|  |  | ||||||
|  | @ -1076,12 +1076,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p | ||||||
| #define CURL_MAX_RETRY 3 | #define CURL_MAX_RETRY 3 | ||||||
| #define CURL_RETRY_DELAY_SECONDS 2 | #define CURL_RETRY_DELAY_SECONDS 2 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| static bool starts_with(const std::string & str, const std::string & prefix) { |  | ||||||
|     // While we wait for C++20's std::string::starts_with...
 |  | ||||||
|     return str.rfind(prefix, 0) == 0; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) { | static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) { | ||||||
|     int remaining_attempts = max_attempts; |     int remaining_attempts = max_attempts; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -37,9 +37,9 @@ using llama_tokens = std::vector<llama_token>; | ||||||
| 
 | 
 | ||||||
| // build info
 | // build info
 | ||||||
| extern int LLAMA_BUILD_NUMBER; | extern int LLAMA_BUILD_NUMBER; | ||||||
| extern char const * LLAMA_COMMIT; | extern const char * LLAMA_COMMIT; | ||||||
| extern char const * LLAMA_COMPILER; | extern const char * LLAMA_COMPILER; | ||||||
| extern char const * LLAMA_BUILD_TARGET; | extern const char * LLAMA_BUILD_TARGET; | ||||||
| 
 | 
 | ||||||
| struct common_control_vector_load_info; | struct common_control_vector_load_info; | ||||||
| 
 | 
 | ||||||
|  | @ -437,6 +437,11 @@ std::vector<std::string> string_split<std::string>(const std::string & input, ch | ||||||
|     return parts; |     return parts; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | static bool string_starts_with(const std::string & str, | ||||||
|  |                                const std::string & prefix) {  // While we wait for C++20's std::string::starts_with...
 | ||||||
|  |     return str.rfind(prefix, 0) == 0; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); | bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); | ||||||
| void string_process_escapes(std::string & input); | void string_process_escapes(std::string & input); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| set(TARGET llama-run) | set(TARGET llama-run) | ||||||
| add_executable(${TARGET} run.cpp) | add_executable(${TARGET} run.cpp) | ||||||
| install(TARGETS ${TARGET} RUNTIME) | install(TARGETS ${TARGET} RUNTIME) | ||||||
| target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) | target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) | ||||||
| target_compile_features(${TARGET} PRIVATE cxx_std_17) | target_compile_features(${TARGET} PRIVATE cxx_std_17) | ||||||
|  |  | ||||||
|  | @ -3,5 +3,45 @@ | ||||||
| The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. | The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. | ||||||
| 
 | 
 | ||||||
| ```bash | ```bash | ||||||
| ./llama-run Meta-Llama-3.1-8B-Instruct.gguf | llama-run granite-code | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | llama-run -h | ||||||
|  | Description: | ||||||
|  |   Runs a llm | ||||||
|  | 
 | ||||||
|  | Usage: | ||||||
|  |   llama-run [options] model [prompt] | ||||||
|  | 
 | ||||||
|  | Options: | ||||||
|  |   -c, --context-size <value> | ||||||
|  |       Context size (default: 2048) | ||||||
|  |   -n, --ngl <value> | ||||||
|  |       Number of GPU layers (default: 0) | ||||||
|  |   -h, --help | ||||||
|  |       Show help message | ||||||
|  | 
 | ||||||
|  | Commands: | ||||||
|  |   model | ||||||
|  |       Model is a string with an optional prefix of | ||||||
|  |       huggingface:// (hf://), ollama://, https:// or file://. | ||||||
|  |       If no protocol is specified and a file exists in the specified | ||||||
|  |       path, file:// is assumed, otherwise if a file does not exist in | ||||||
|  |       the specified path, ollama:// is assumed. Models that are being | ||||||
|  |       pulled are downloaded with .partial extension while being | ||||||
|  |       downloaded and then renamed as the file without the .partial | ||||||
|  |       extension when complete. | ||||||
|  | 
 | ||||||
|  | Examples: | ||||||
|  |   llama-run llama3 | ||||||
|  |   llama-run ollama://granite-code | ||||||
|  |   llama-run ollama://smollm:135m | ||||||
|  |   llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf | ||||||
|  |   llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf | ||||||
|  |   llama-run https://example.com/some-file1.gguf | ||||||
|  |   llama-run some-file2.gguf | ||||||
|  |   llama-run file://some-file3.gguf | ||||||
|  |   llama-run --ngl 99 some-file4.gguf | ||||||
|  |   llama-run --ngl 99 some-file5.gguf Hello World | ||||||
| ... | ... | ||||||
|  |  | ||||||
|  | @ -4,110 +4,330 @@ | ||||||
| #    include <unistd.h> | #    include <unistd.h> | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| #include <climits> | #if defined(LLAMA_USE_CURL) | ||||||
|  | #    include <curl/curl.h> | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  | #include <cstdarg> | ||||||
| #include <cstdio> | #include <cstdio> | ||||||
| #include <cstring> | #include <cstring> | ||||||
|  | #include <filesystem> | ||||||
| #include <iostream> | #include <iostream> | ||||||
| #include <sstream> | #include <sstream> | ||||||
| #include <string> | #include <string> | ||||||
| #include <unordered_map> |  | ||||||
| #include <vector> | #include <vector> | ||||||
| 
 | 
 | ||||||
|  | #include "common.h" | ||||||
|  | #include "json.hpp" | ||||||
| #include "llama-cpp.h" | #include "llama-cpp.h" | ||||||
| 
 | 
 | ||||||
| typedef std::unique_ptr<char[]> char_array_ptr; | #define printe(...)                   \ | ||||||
|  |     do {                              \ | ||||||
|  |         fprintf(stderr, __VA_ARGS__); \ | ||||||
|  |     } while (0) | ||||||
| 
 | 
 | ||||||
| struct Argument { | class Opt { | ||||||
|     std::string flag; |  | ||||||
|     std::string help_text; |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| struct Options { |  | ||||||
|     std::string model_path, prompt_non_interactive; |  | ||||||
|     int ngl = 99; |  | ||||||
|     int n_ctx = 2048; |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| class ArgumentParser { |  | ||||||
|   public: |   public: | ||||||
|     ArgumentParser(const char * program_name) : program_name(program_name) {} |     int init(int argc, const char ** argv) { | ||||||
| 
 |         construct_help_str_(); | ||||||
|     void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") { |         // Parse arguments
 | ||||||
|         string_args[flag] = &var; |         if (parse(argc, argv)) { | ||||||
|         arguments.push_back({flag, help_text}); |             printe("Error: Failed to parse arguments.\n"); | ||||||
|  |             help(); | ||||||
|  |             return 1; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     void add_argument(const std::string & flag, int & var, const std::string & help_text = "") { |         // If help is requested, show help and exit
 | ||||||
|         int_args[flag] = &var; |         if (help_) { | ||||||
|         arguments.push_back({flag, help_text}); |             help(); | ||||||
|  |             return 2; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         return 0;  // Success
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     std::string model_; | ||||||
|  |     std::string user_; | ||||||
|  |     int         context_size_ = 2048, ngl_ = -1; | ||||||
|  | 
 | ||||||
|  |   private: | ||||||
|  |     std::string help_str_; | ||||||
|  |     bool        help_ = false; | ||||||
|  | 
 | ||||||
|  |     void construct_help_str_() { | ||||||
|  |         help_str_ = | ||||||
|  |             "Description:\n" | ||||||
|  |             "  Runs a llm\n" | ||||||
|  |             "\n" | ||||||
|  |             "Usage:\n" | ||||||
|  |             "  llama-run [options] model [prompt]\n" | ||||||
|  |             "\n" | ||||||
|  |             "Options:\n" | ||||||
|  |             "  -c, --context-size <value>\n" | ||||||
|  |             "      Context size (default: " + | ||||||
|  |             std::to_string(context_size_); | ||||||
|  |         help_str_ += | ||||||
|  |             ")\n" | ||||||
|  |             "  -n, --ngl <value>\n" | ||||||
|  |             "      Number of GPU layers (default: " + | ||||||
|  |             std::to_string(ngl_); | ||||||
|  |         help_str_ += | ||||||
|  |             ")\n" | ||||||
|  |             "  -h, --help\n" | ||||||
|  |             "      Show help message\n" | ||||||
|  |             "\n" | ||||||
|  |             "Commands:\n" | ||||||
|  |             "  model\n" | ||||||
|  |             "      Model is a string with an optional prefix of \n" | ||||||
|  |             "      huggingface:// (hf://), ollama://, https:// or file://.\n" | ||||||
|  |             "      If no protocol is specified and a file exists in the specified\n" | ||||||
|  |             "      path, file:// is assumed, otherwise if a file does not exist in\n" | ||||||
|  |             "      the specified path, ollama:// is assumed. Models that are being\n" | ||||||
|  |             "      pulled are downloaded with .partial extension while being\n" | ||||||
|  |             "      downloaded and then renamed as the file without the .partial\n" | ||||||
|  |             "      extension when complete.\n" | ||||||
|  |             "\n" | ||||||
|  |             "Examples:\n" | ||||||
|  |             "  llama-run llama3\n" | ||||||
|  |             "  llama-run ollama://granite-code\n" | ||||||
|  |             "  llama-run ollama://smollm:135m\n" | ||||||
|  |             "  llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" | ||||||
|  |             "  llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" | ||||||
|  |             "  llama-run https://example.com/some-file1.gguf\n" | ||||||
|  |             "  llama-run some-file2.gguf\n" | ||||||
|  |             "  llama-run file://some-file3.gguf\n" | ||||||
|  |             "  llama-run --ngl 99 some-file4.gguf\n" | ||||||
|  |             "  llama-run --ngl 99 some-file5.gguf Hello World\n"; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     int parse(int argc, const char ** argv) { |     int parse(int argc, const char ** argv) { | ||||||
|  |         int positional_args_i = 0; | ||||||
|         for (int i = 1; i < argc; ++i) { |         for (int i = 1; i < argc; ++i) { | ||||||
|             std::string arg = argv[i]; |             if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) { | ||||||
|             if (string_args.count(arg)) { |                 if (i + 1 >= argc) { | ||||||
|                 if (i + 1 < argc) { |  | ||||||
|                     *string_args[arg] = argv[++i]; |  | ||||||
|                 } else { |  | ||||||
|                     fprintf(stderr, "error: missing value for %s\n", arg.c_str()); |  | ||||||
|                     print_usage(); |  | ||||||
|                     return 1; |                     return 1; | ||||||
|                 } |                 } | ||||||
|             } else if (int_args.count(arg)) { | 
 | ||||||
|                 if (i + 1 < argc) { |                 context_size_ = std::atoi(argv[++i]); | ||||||
|                     if (parse_int_arg(argv[++i], *int_args[arg]) != 0) { |             } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) { | ||||||
|                         fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]); |                 if (i + 1 >= argc) { | ||||||
|                         print_usage(); |  | ||||||
|                     return 1; |                     return 1; | ||||||
|                 } |                 } | ||||||
|  | 
 | ||||||
|  |                 ngl_ = std::atoi(argv[++i]); | ||||||
|  |             } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { | ||||||
|  |                 help_ = true; | ||||||
|  |                 return 0; | ||||||
|  |             } else if (!positional_args_i) { | ||||||
|  |                 ++positional_args_i; | ||||||
|  |                 model_ = argv[i]; | ||||||
|  |             } else if (positional_args_i == 1) { | ||||||
|  |                 ++positional_args_i; | ||||||
|  |                 user_ = argv[i]; | ||||||
|             } else { |             } else { | ||||||
|                     fprintf(stderr, "error: missing value for %s\n", arg.c_str()); |                 user_ += " " + std::string(argv[i]); | ||||||
|                     print_usage(); |  | ||||||
|                     return 1; |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str()); |  | ||||||
|                 print_usage(); |  | ||||||
|                 return 1; |  | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if (string_args["-m"]->empty()) { |         return model_.empty();  // model_ is the only required value
 | ||||||
|             fprintf(stderr, "error: -m is required\n"); |     } | ||||||
|             print_usage(); | 
 | ||||||
|  |     void help() const { printf("%s", help_str_.c_str()); } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | struct progress_data { | ||||||
|  |     size_t file_size = 0; | ||||||
|  |     std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); | ||||||
|  |     bool   printed   = false; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | struct FileDeleter { | ||||||
|  |     void operator()(FILE * file) const { | ||||||
|  |         if (file) { | ||||||
|  |             fclose(file); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | typedef std::unique_ptr<FILE, FileDeleter> FILE_ptr; | ||||||
|  | 
 | ||||||
|  | #ifdef LLAMA_USE_CURL | ||||||
|  | class CurlWrapper { | ||||||
|  |   public: | ||||||
|  |     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file, | ||||||
|  |              const bool progress, std::string * response_str = nullptr) { | ||||||
|  |         std::string output_file_partial; | ||||||
|  |         curl = curl_easy_init(); | ||||||
|  |         if (!curl) { | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         progress_data data; | ||||||
|  |         FILE_ptr      out; | ||||||
|  |         if (!output_file.empty()) { | ||||||
|  |             output_file_partial = output_file + ".partial"; | ||||||
|  |             out.reset(fopen(output_file_partial.c_str(), "ab")); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         set_write_options(response_str, out); | ||||||
|  |         data.file_size = set_resume_point(output_file_partial); | ||||||
|  |         set_progress_options(progress, data); | ||||||
|  |         set_headers(headers); | ||||||
|  |         perform(url); | ||||||
|  |         if (!output_file.empty()) { | ||||||
|  |             std::filesystem::rename(output_file_partial, output_file); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         return 0; |         return 0; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     ~CurlWrapper() { | ||||||
|  |         if (chunk) { | ||||||
|  |             curl_slist_free_all(chunk); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (curl) { | ||||||
|  |             curl_easy_cleanup(curl); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|   private: |   private: | ||||||
|     const char * program_name; |     CURL *              curl  = nullptr; | ||||||
|     std::unordered_map<std::string, std::string *> string_args; |     struct curl_slist * chunk = nullptr; | ||||||
|     std::unordered_map<std::string, int *> int_args; |  | ||||||
|     std::vector<Argument> arguments; |  | ||||||
| 
 | 
 | ||||||
|     int parse_int_arg(const char * arg, int & value) { |     void set_write_options(std::string * response_str, const FILE_ptr & out) { | ||||||
|         char * end; |         if (response_str) { | ||||||
|         const long val = std::strtol(arg, &end, 10); |             curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); | ||||||
|         if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) { |             curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str); | ||||||
|             value = static_cast<int>(val); |         } else { | ||||||
|  |             curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); | ||||||
|  |             curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.get()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     size_t set_resume_point(const std::string & output_file) { | ||||||
|  |         size_t file_size = 0; | ||||||
|  |         if (std::filesystem::exists(output_file)) { | ||||||
|  |             file_size = std::filesystem::file_size(output_file); | ||||||
|  |             curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast<curl_off_t>(file_size)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         return file_size; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     void set_progress_options(bool progress, progress_data & data) { | ||||||
|  |         if (progress) { | ||||||
|  |             curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); | ||||||
|  |             curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data); | ||||||
|  |             curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     void set_headers(const std::vector<std::string> & headers) { | ||||||
|  |         if (!headers.empty()) { | ||||||
|  |             if (chunk) { | ||||||
|  |                 curl_slist_free_all(chunk); | ||||||
|  |                 chunk = 0; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             for (const auto & header : headers) { | ||||||
|  |                 chunk = curl_slist_append(chunk, header.c_str()); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     void perform(const std::string & url) { | ||||||
|  |         CURLcode res; | ||||||
|  |         curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); | ||||||
|  |         curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); | ||||||
|  |         curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); | ||||||
|  |         curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); | ||||||
|  |         res = curl_easy_perform(curl); | ||||||
|  |         if (res != CURLE_OK) { | ||||||
|  |             printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res)); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static std::string human_readable_time(double seconds) { | ||||||
|  |         int hrs  = static_cast<int>(seconds) / 3600; | ||||||
|  |         int mins = (static_cast<int>(seconds) % 3600) / 60; | ||||||
|  |         int secs = static_cast<int>(seconds) % 60; | ||||||
|  | 
 | ||||||
|  |         std::ostringstream out; | ||||||
|  |         if (hrs > 0) { | ||||||
|  |             out << hrs << "h " << std::setw(2) << std::setfill('0') << mins << "m " << std::setw(2) << std::setfill('0') | ||||||
|  |                 << secs << "s"; | ||||||
|  |         } else if (mins > 0) { | ||||||
|  |             out << mins << "m " << std::setw(2) << std::setfill('0') << secs << "s"; | ||||||
|  |         } else { | ||||||
|  |             out << secs << "s"; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         return out.str(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static std::string human_readable_size(curl_off_t size) { | ||||||
|  |         static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" }; | ||||||
|  |         char         length   = sizeof(suffix) / sizeof(suffix[0]); | ||||||
|  |         int          i        = 0; | ||||||
|  |         double       dbl_size = size; | ||||||
|  |         if (size > 1024) { | ||||||
|  |             for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) { | ||||||
|  |                 dbl_size = size / 1024.0; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         std::ostringstream out; | ||||||
|  |         out << std::fixed << std::setprecision(2) << dbl_size << " " << suffix[i]; | ||||||
|  |         return out.str(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static int progress_callback(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t, | ||||||
|  |                                  curl_off_t) { | ||||||
|  |         progress_data * data = static_cast<progress_data *>(ptr); | ||||||
|  |         if (total_to_download <= 0) { | ||||||
|             return 0; |             return 0; | ||||||
|         } |         } | ||||||
|         return 1; | 
 | ||||||
|  |         total_to_download += data->file_size; | ||||||
|  |         const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size; | ||||||
|  |         const curl_off_t percentage                    = (now_downloaded_plus_file_size * 100) / total_to_download; | ||||||
|  |         const curl_off_t pos                           = (percentage / 5); | ||||||
|  |         std::string progress_bar; | ||||||
|  |         for (int i = 0; i < 20; ++i) { | ||||||
|  |             progress_bar.append((i < pos) ? "█" : " "); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     void print_usage() const { |         // Calculate download speed and estimated time to completion
 | ||||||
|         printf("\nUsage:\n"); |         const auto                          now             = std::chrono::steady_clock::now(); | ||||||
|         printf("  %s [OPTIONS]\n\n", program_name); |         const std::chrono::duration<double> elapsed_seconds = now - data->start_time; | ||||||
|         printf("Options:\n"); |         const double                        speed           = now_downloaded / elapsed_seconds.count(); | ||||||
|         for (const auto & arg : arguments) { |         const double                        estimated_time  = (total_to_download - now_downloaded) / speed; | ||||||
|             printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str()); |         printe("\r%ld%% |%s| %s/%s  %.2f MB/s  %s      ", percentage, progress_bar.c_str(), | ||||||
|  |                human_readable_size(now_downloaded).c_str(), human_readable_size(total_to_download).c_str(), | ||||||
|  |                speed / (1024 * 1024), human_readable_time(estimated_time).c_str()); | ||||||
|  |         fflush(stderr); | ||||||
|  |         data->printed = true; | ||||||
|  | 
 | ||||||
|  |         return 0; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|         printf("\n"); |     // Function to write data to a file
 | ||||||
|  |     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) { | ||||||
|  |         FILE * out = static_cast<FILE *>(stream); | ||||||
|  |         return fwrite(ptr, size, nmemb, out); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Function to capture data into a string
 | ||||||
|  |     static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) { | ||||||
|  |         std::string * str = static_cast<std::string *>(stream); | ||||||
|  |         str->append(static_cast<char *>(ptr), size * nmemb); | ||||||
|  |         return size * nmemb; | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  | #endif | ||||||
| 
 | 
 | ||||||
| class LlamaData { | class LlamaData { | ||||||
|   public: |   public: | ||||||
|  | @ -115,14 +335,16 @@ class LlamaData { | ||||||
|     llama_sampler_ptr               sampler; |     llama_sampler_ptr               sampler; | ||||||
|     llama_context_ptr               context; |     llama_context_ptr               context; | ||||||
|     std::vector<llama_chat_message> messages; |     std::vector<llama_chat_message> messages; | ||||||
|  |     std::vector<std::string>        msg_strs; | ||||||
|  |     std::vector<char>               fmtted; | ||||||
| 
 | 
 | ||||||
|     int init(const Options & opt) { |     int init(Opt & opt) { | ||||||
|         model = initialize_model(opt.model_path, opt.ngl); |         model = initialize_model(opt); | ||||||
|         if (!model) { |         if (!model) { | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         context = initialize_context(model, opt.n_ctx); |         context = initialize_context(model, opt.context_size_); | ||||||
|         if (!context) { |         if (!context) { | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|  | @ -132,14 +354,122 @@ class LlamaData { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|     // Initializes the model and returns a unique pointer to it
 | #ifdef LLAMA_USE_CURL | ||||||
|     llama_model_ptr initialize_model(const std::string & model_path, const int ngl) { |     int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file, | ||||||
|         llama_model_params model_params = llama_model_default_params(); |                  const bool progress, std::string * response_str = nullptr) { | ||||||
|         model_params.n_gpu_layers = ngl; |         CurlWrapper curl; | ||||||
|  |         if (curl.init(url, headers, output_file, progress, response_str)) { | ||||||
|  |             return 1; | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         llama_model_ptr model(llama_load_model_from_file(model_path.c_str(), model_params)); |         return 0; | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool, | ||||||
|  |                  std::string * = nullptr) { | ||||||
|  |         printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); | ||||||
|  |         return 1; | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  |     int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) { | ||||||
|  |         // Find the second occurrence of '/' after protocol string
 | ||||||
|  |         size_t pos = model.find('/'); | ||||||
|  |         pos        = model.find('/', pos + 1); | ||||||
|  |         if (pos == std::string::npos) { | ||||||
|  |             return 1; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         const std::string hfr = model.substr(0, pos); | ||||||
|  |         const std::string hff = model.substr(pos + 1); | ||||||
|  |         const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff; | ||||||
|  |         return download(url, headers, bn, true); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) { | ||||||
|  |         if (model.find('/') == std::string::npos) { | ||||||
|  |             model = "library/" + model; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         std::string model_tag = "latest"; | ||||||
|  |         size_t      colon_pos = model.find(':'); | ||||||
|  |         if (colon_pos != std::string::npos) { | ||||||
|  |             model_tag = model.substr(colon_pos + 1); | ||||||
|  |             model     = model.substr(0, colon_pos); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag; | ||||||
|  |         std::string manifest_str; | ||||||
|  |         const int   ret = download(manifest_url, headers, "", false, &manifest_str); | ||||||
|  |         if (ret) { | ||||||
|  |             return ret; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         nlohmann::json manifest = nlohmann::json::parse(manifest_str); | ||||||
|  |         std::string    layer; | ||||||
|  |         for (const auto & l : manifest["layers"]) { | ||||||
|  |             if (l["mediaType"] == "application/vnd.ollama.image.model") { | ||||||
|  |                 layer = l["digest"]; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer; | ||||||
|  |         return download(blob_url, headers, bn, true); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     std::string basename(const std::string & path) { | ||||||
|  |         const size_t pos = path.find_last_of("/\\"); | ||||||
|  |         if (pos == std::string::npos) { | ||||||
|  |             return path; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         return path.substr(pos + 1); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     int remove_proto(std::string & model_) { | ||||||
|  |         const std::string::size_type pos = model_.find("://"); | ||||||
|  |         if (pos == std::string::npos) { | ||||||
|  |             return 1; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         model_ = model_.substr(pos + 3);  // Skip past "://"
 | ||||||
|  |         return 0; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     int resolve_model(std::string & model_) { | ||||||
|  |         const std::string              bn      = basename(model_); | ||||||
|  |         const std::vector<std::string> headers = { "--header", | ||||||
|  |                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" }; | ||||||
|  |         int                            ret     = 0; | ||||||
|  |         if (string_starts_with(model_, "file://") || std::filesystem::exists(bn)) { | ||||||
|  |             remove_proto(model_); | ||||||
|  |         } else if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) { | ||||||
|  |             remove_proto(model_); | ||||||
|  |             ret = huggingface_dl(model_, headers, bn); | ||||||
|  |         } else if (string_starts_with(model_, "ollama://")) { | ||||||
|  |             remove_proto(model_); | ||||||
|  |             ret = ollama_dl(model_, headers, bn); | ||||||
|  |         } else if (string_starts_with(model_, "https://")) { | ||||||
|  |             download(model_, headers, bn, true); | ||||||
|  |         } else { | ||||||
|  |             ret = ollama_dl(model_, headers, bn); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         model_ = bn; | ||||||
|  | 
 | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Initializes the model and returns a unique pointer to it
 | ||||||
|  |     llama_model_ptr initialize_model(Opt & opt) { | ||||||
|  |         ggml_backend_load_all(); | ||||||
|  |         llama_model_params model_params = llama_model_default_params(); | ||||||
|  |         model_params.n_gpu_layers       = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers; | ||||||
|  |         resolve_model(opt.model_); | ||||||
|  |         llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params)); | ||||||
|         if (!model) { |         if (!model) { | ||||||
|             fprintf(stderr, "%s: error: unable to load model\n", __func__); |             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return model; |         return model; | ||||||
|  | @ -150,10 +480,9 @@ class LlamaData { | ||||||
|         llama_context_params ctx_params = llama_context_default_params(); |         llama_context_params ctx_params = llama_context_default_params(); | ||||||
|         ctx_params.n_ctx                = n_ctx; |         ctx_params.n_ctx                = n_ctx; | ||||||
|         ctx_params.n_batch              = n_ctx; |         ctx_params.n_batch              = n_ctx; | ||||||
| 
 |  | ||||||
|         llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params)); |         llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params)); | ||||||
|         if (!context) { |         if (!context) { | ||||||
|             fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__); |             printe("%s: error: failed to create the llama_context\n", __func__); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return context; |         return context; | ||||||
|  | @ -170,23 +499,22 @@ class LlamaData { | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // Add a message to `messages` and store its content in `owned_content`
 | // Add a message to `messages` and store its content in `msg_strs`
 | ||||||
| static void add_message(const char * role, const std::string & text, LlamaData & llama_data, | static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { | ||||||
|                         std::vector<char_array_ptr> & owned_content) { |     llama_data.msg_strs.push_back(std::move(text)); | ||||||
|     char_array_ptr content(new char[text.size() + 1]); |     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); | ||||||
|     std::strcpy(content.get(), text.c_str()); |  | ||||||
|     llama_data.messages.push_back({role, content.get()}); |  | ||||||
|     owned_content.push_back(std::move(content)); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Function to apply the chat template and resize `formatted` if needed
 | // Function to apply the chat template and resize `formatted` if needed
 | ||||||
| static int apply_chat_template(const LlamaData & llama_data, std::vector<char> & formatted, const bool append) { | static int apply_chat_template(LlamaData & llama_data, const bool append) { | ||||||
|     int result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), |     int result = llama_chat_apply_template( | ||||||
|                                            llama_data.messages.size(), append, formatted.data(), formatted.size()); |         llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append, | ||||||
|     if (result > static_cast<int>(formatted.size())) { |         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); | ||||||
|         formatted.resize(result); |     if (append && result > static_cast<int>(llama_data.fmtted.size())) { | ||||||
|  |         llama_data.fmtted.resize(result); | ||||||
|         result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), |         result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), | ||||||
|                                            llama_data.messages.size(), append, formatted.data(), formatted.size()); |                                            llama_data.messages.size(), append, llama_data.fmtted.data(), | ||||||
|  |                                            llama_data.fmtted.size()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     return result; |     return result; | ||||||
|  | @ -199,7 +527,8 @@ static int tokenize_prompt(const llama_model_ptr & model, const std::string & pr | ||||||
|     prompt_tokens.resize(n_prompt_tokens); |     prompt_tokens.resize(n_prompt_tokens); | ||||||
|     if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, |     if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, | ||||||
|                        true) < 0) { |                        true) < 0) { | ||||||
|         GGML_ABORT("failed to tokenize the prompt\n"); |         printe("failed to tokenize the prompt\n"); | ||||||
|  |         return -1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     return n_prompt_tokens; |     return n_prompt_tokens; | ||||||
|  | @ -211,7 +540,7 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & | ||||||
|     const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); |     const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); | ||||||
|     if (n_ctx_used + batch.n_tokens > n_ctx) { |     if (n_ctx_used + batch.n_tokens > n_ctx) { | ||||||
|         printf("\033[0m\n"); |         printf("\033[0m\n"); | ||||||
|         fprintf(stderr, "context size exceeded\n"); |         printe("context size exceeded\n"); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -223,7 +552,8 @@ static int convert_token_to_string(const llama_model_ptr & model, const llama_to | ||||||
|     char buf[256]; |     char buf[256]; | ||||||
|     int  n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); |     int  n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); | ||||||
|     if (n < 0) { |     if (n < 0) { | ||||||
|         GGML_ABORT("failed to convert token to piece\n"); |         printe("failed to convert token to piece\n"); | ||||||
|  |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     piece = std::string(buf, n); |     piece = std::string(buf, n); | ||||||
|  | @ -238,19 +568,19 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st | ||||||
| 
 | 
 | ||||||
| // helper function to evaluate a prompt and generate a response
 | // helper function to evaluate a prompt and generate a response
 | ||||||
| static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { | static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { | ||||||
|     std::vector<llama_token> prompt_tokens; |     std::vector<llama_token> tokens; | ||||||
|     const int n_prompt_tokens = tokenize_prompt(llama_data.model, prompt, prompt_tokens); |     if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) { | ||||||
|     if (n_prompt_tokens < 0) { |  | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // prepare a batch for the prompt
 |     // prepare a batch for the prompt
 | ||||||
|     llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); |     llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); | ||||||
|     llama_token new_token_id; |     llama_token new_token_id; | ||||||
|     while (true) { |     while (true) { | ||||||
|         check_context_size(llama_data.context, batch); |         check_context_size(llama_data.context, batch); | ||||||
|         if (llama_decode(llama_data.context.get(), batch)) { |         if (llama_decode(llama_data.context.get(), batch)) { | ||||||
|             GGML_ABORT("failed to decode\n"); |             printe("failed to decode\n"); | ||||||
|  |             return 1; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // sample the next token, check is it an end of generation?
 |         // sample the next token, check is it an end of generation?
 | ||||||
|  | @ -273,22 +603,9 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| static int parse_arguments(const int argc, const char ** argv, Options & opt) { |  | ||||||
|     ArgumentParser parser(argv[0]); |  | ||||||
|     parser.add_argument("-m", opt.model_path, "model"); |  | ||||||
|     parser.add_argument("-p", opt.prompt_non_interactive, "prompt"); |  | ||||||
|     parser.add_argument("-c", opt.n_ctx, "context_size"); |  | ||||||
|     parser.add_argument("-ngl", opt.ngl, "n_gpu_layers"); |  | ||||||
|     if (parser.parse(argc, argv)) { |  | ||||||
|         return 1; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     return 0; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| static int read_user_input(std::string & user) { | static int read_user_input(std::string & user) { | ||||||
|     std::getline(std::cin, user); |     std::getline(std::cin, user); | ||||||
|     return user.empty();  // Indicate an error or empty input
 |     return user.empty();  // Should have data in happy path
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Function to generate a response based on the prompt
 | // Function to generate a response based on the prompt
 | ||||||
|  | @ -296,7 +613,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, | ||||||
|     // Set response color
 |     // Set response color
 | ||||||
|     printf("\033[33m"); |     printf("\033[33m"); | ||||||
|     if (generate(llama_data, prompt, response)) { |     if (generate(llama_data, prompt, response)) { | ||||||
|         fprintf(stderr, "failed to generate response\n"); |         printe("failed to generate response\n"); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -306,11 +623,10 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Helper function to apply the chat template and handle errors
 | // Helper function to apply the chat template and handle errors
 | ||||||
| static int apply_chat_template_with_error_handling(const LlamaData & llama_data, std::vector<char> & formatted, | static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) { | ||||||
|                                                    const bool is_user_input, int & output_length) { |     const int new_len = apply_chat_template(llama_data, append); | ||||||
|     const int new_len = apply_chat_template(llama_data, formatted, is_user_input); |  | ||||||
|     if (new_len < 0) { |     if (new_len < 0) { | ||||||
|         fprintf(stderr, "failed to apply the chat template\n"); |         printe("failed to apply the chat template\n"); | ||||||
|         return -1; |         return -1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -319,49 +635,56 @@ static int apply_chat_template_with_error_handling(const LlamaData & llama_data, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Helper function to handle user input
 | // Helper function to handle user input
 | ||||||
| static bool handle_user_input(std::string & user_input, const std::string & prompt_non_interactive) { | static int handle_user_input(std::string & user_input, const std::string & user_) { | ||||||
|     if (!prompt_non_interactive.empty()) { |     if (!user_.empty()) { | ||||||
|         user_input = prompt_non_interactive; |         user_input = user_; | ||||||
|         return true;  // No need for interactive input
 |         return 0;  // No need for interactive input
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     printf("\033[32m> \033[0m"); |     printf( | ||||||
|     return !read_user_input(user_input);  // Returns false if input ends the loop
 |         "\r                                                                       " | ||||||
|  |         "\r\033[32m> \033[0m"); | ||||||
|  |     return read_user_input(user_input);  // Returns true if input ends the loop
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Function to tokenize the prompt
 | // Function to tokenize the prompt
 | ||||||
| static int chat_loop(LlamaData & llama_data, std::string & prompt_non_interactive) { | static int chat_loop(LlamaData & llama_data, const std::string & user_) { | ||||||
|     std::vector<char_array_ptr> owned_content; |  | ||||||
|     std::vector<char> fmtted(llama_n_ctx(llama_data.context.get())); |  | ||||||
|     int prev_len = 0; |     int prev_len = 0; | ||||||
| 
 |     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); | ||||||
|     while (true) { |     while (true) { | ||||||
|         // Get user input
 |         // Get user input
 | ||||||
|         std::string user_input; |         std::string user_input; | ||||||
|         if (!handle_user_input(user_input, prompt_non_interactive)) { |         while (handle_user_input(user_input, user_)) { | ||||||
|             break; |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         add_message("user", prompt_non_interactive.empty() ? user_input : prompt_non_interactive, llama_data, |         add_message("user", user_.empty() ? user_input : user_, llama_data); | ||||||
|                     owned_content); |  | ||||||
| 
 |  | ||||||
|         int new_len; |         int new_len; | ||||||
|         if (apply_chat_template_with_error_handling(llama_data, fmtted, true, new_len) < 0) { |         if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) { | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         std::string prompt(fmtted.begin() + prev_len, fmtted.begin() + new_len); |         std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); | ||||||
|         std::string response; |         std::string response; | ||||||
|         if (generate_response(llama_data, prompt, response)) { |         if (generate_response(llama_data, prompt, response)) { | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|  | 
 | ||||||
|  |         if (!user_.empty()) { | ||||||
|  |             break; | ||||||
|         } |         } | ||||||
|  | 
 | ||||||
|  |         add_message("assistant", response, llama_data); | ||||||
|  |         if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) { | ||||||
|  |             return 1; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| static void log_callback(const enum ggml_log_level level, const char * text, void *) { | static void log_callback(const enum ggml_log_level level, const char * text, void *) { | ||||||
|     if (level == GGML_LOG_LEVEL_ERROR) { |     if (level == GGML_LOG_LEVEL_ERROR) { | ||||||
|         fprintf(stderr, "%s", text); |         printe("%s", text); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -382,17 +705,20 @@ static std::string read_pipe_data() { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int main(int argc, const char ** argv) { | int main(int argc, const char ** argv) { | ||||||
|     Options opt; |     Opt       opt; | ||||||
|     if (parse_arguments(argc, argv, opt)) { |     const int ret = opt.init(argc, argv); | ||||||
|  |     if (ret == 2) { | ||||||
|  |         return 0; | ||||||
|  |     } else if (ret) { | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     if (!is_stdin_a_terminal()) { |     if (!is_stdin_a_terminal()) { | ||||||
|         if (!opt.prompt_non_interactive.empty()) { |         if (!opt.user_.empty()) { | ||||||
|             opt.prompt_non_interactive += "\n\n"; |             opt.user_ += "\n\n"; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         opt.prompt_non_interactive += read_pipe_data(); |         opt.user_ += read_pipe_data(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     llama_log_set(log_callback, nullptr); |     llama_log_set(log_callback, nullptr); | ||||||
|  | @ -401,7 +727,7 @@ int main(int argc, const char ** argv) { | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     if (chat_loop(llama_data, opt.prompt_non_interactive)) { |     if (chat_loop(llama_data, opt.user_)) { | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue