diff --git a/ggml-internal.hpp b/ggml-internal.hpp index 29ae01198..0725451fc 100644 --- a/ggml-internal.hpp +++ b/ggml-internal.hpp @@ -90,3 +90,169 @@ struct ggml_allocr { ggml_tallocr_t talloc; ggml_gallocr_t galloc; }; + +#define GGML_NUMA_MAX_NODES 8 +#define GGML_NUMA_MAX_CPUS 512 + +struct ggml_numa_node { + uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node + uint32_t n_cpus; +}; + +struct ggml_numa_nodes { + struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; + uint32_t n_nodes; + uint32_t total_cpus; // hardware threads on system +}; + +struct ggml_state { + struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; + struct ggml_numa_nodes numa; + + ggml_state():contexts(), numa() + { + + } +}; + +struct gguf_str { + uint64_t n; // GGUFv2 + char * data; +}; + +struct ggml_map_custom1_op_params { + ggml_custom1_op_t fun; + int n_tasks; + void * userdata; +}; + +struct ggml_map_custom2_op_params { + ggml_custom2_op_t fun; + int n_tasks; + void * userdata; +}; + +struct ggml_map_custom3_op_params { + ggml_custom3_op_t fun; + int n_tasks; + void * userdata; +}; +struct hash_map { + struct ggml_hash_set set; + struct ggml_tensor ** vals; +}; + +#if defined(_WIN32) +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; +#else +#include +using namespace std; +#endif + +struct ggml_compute_state_shared { + const struct ggml_cgraph * cgraph; + const struct ggml_cplan * cplan; + + int64_t perf_node_start_cycles; + int64_t perf_node_start_time_us; + + const int n_threads; + + // synchronization primitives + atomic_int n_active; // num active threads + atomic_int node_n; // active graph node + + bool (*abort_callback)(void * data); // abort ggml_graph_compute when true + void * abort_callback_data; +}; +typedef pthread_t ggml_thread_t; +struct ggml_compute_state { + ggml_thread_t thrd; + int ith; + struct ggml_compute_state_shared * shared; +}; + +union gguf_value { + uint8_t uint8; + int8_t int8; + uint16_t uint16; + int16_t int16; + uint32_t uint32; + int32_t int32; + float float32; + uint64_t uint64; + int64_t int64; + double float64; + bool bool_; + + struct gguf_str str; + + struct gguf_array_T { + enum gguf_type type; + + uint64_t n; // GGUFv2 + void * data; + } arr; +}; + +struct ggml_lbfgs_iteration_data { + float alpha; + float ys; + float * s; + float * y; +}; + +struct gguf_kv { + struct gguf_str key; + + enum gguf_type type; + union gguf_value value; +}; + + + +struct gguf_header { + char magic[4]; + uint32_t version; + uint64_t n_tensors; // GGUFv2 + uint64_t n_kv; // GGUFv2 +}; + +struct gguf_tensor_info { + struct gguf_str name; + + uint32_t n_dims; + uint64_t ne[GGML_MAX_DIMS]; + + enum ggml_type type; + + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` + + // for writing API + const void * data; + size_t size; +}; + +struct gguf_context { + struct gguf_header header; + + struct gguf_kv * kv; + struct gguf_tensor_info * infos; + + size_t alignment; + size_t offset; // offset of `data` from beginning of file + size_t size; // size of `data` in bytes + + //uint8_t * padding; + void * data; +}; + +struct gguf_buf { + void * data; + size_t size; + size_t offset; +}; + + +#include "ggml-backend-impl.h" diff --git a/ggml.cpp b/ggml.cpp index 53e312ac3..9d2ab8ebd 100644 --- a/ggml.cpp +++ b/ggml.cpp @@ -1625,33 +1625,12 @@ static void ggml_setup_op_has_task_pass(void) { // NUMA support // -#define GGML_NUMA_MAX_NODES 8 -#define GGML_NUMA_MAX_CPUS 512 -struct ggml_numa_node { - uint32_t cpus[GGML_NUMA_MAX_CPUS]; // 
hardware threads on this node - uint32_t n_cpus; -}; - -struct ggml_numa_nodes { - struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; - uint32_t n_nodes; - uint32_t total_cpus; // hardware threads on system -}; // // ggml state // -struct ggml_state { - struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; - struct ggml_numa_nodes numa; - - ggml_state():contexts(), numa() - { - - } -}; // global state static struct ggml_state g_state; @@ -1986,10 +1965,6 @@ static inline int ggml_up(int n, int m) { //////////////////////////////////////////////////////////////////////////////// static size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT]={}; -struct gguf_str { - uint64_t n; // GGUFv2 - char * data; -}; static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {}; @@ -6084,11 +6059,6 @@ struct ggml_tensor * ggml_map_custom3_inplace_f32( } // ggml_map_custom1 -struct ggml_map_custom1_op_params { - ggml_custom1_op_t fun; - int n_tasks; - void * userdata; -}; static struct ggml_tensor * ggml_map_custom1_impl( struct ggml_context * ctx, @@ -6141,11 +6111,6 @@ struct ggml_tensor * ggml_map_custom1_inplace( // ggml_map_custom2 -struct ggml_map_custom2_op_params { - ggml_custom2_op_t fun; - int n_tasks; - void * userdata; -}; static struct ggml_tensor * ggml_map_custom2_impl( struct ggml_context * ctx, @@ -6202,11 +6167,6 @@ struct ggml_tensor * ggml_map_custom2_inplace( // ggml_map_custom3 -struct ggml_map_custom3_op_params { - ggml_custom3_op_t fun; - int n_tasks; - void * userdata; -}; static struct ggml_tensor * ggml_map_custom3_impl( struct ggml_context * ctx, @@ -14475,10 +14435,6 @@ static void ggml_hash_set_free(struct ggml_hash_set hash_set) { free(hash_set.keys); } -struct hash_map { - struct ggml_hash_set set; - struct ggml_tensor ** vals; -}; static struct hash_map * ggml_new_hash_map(size_t size) { struct hash_map * result = (hash_map *)malloc(sizeof(struct hash_map)); @@ -15734,7 +15690,7 @@ typedef int ggml_lock_t; #define GGML_LOCK_INITIALIZER 0 -typedef pthread_t ggml_thread_t; + #define ggml_thread_create pthread_create #define ggml_thread_join pthread_join @@ -15824,28 +15780,7 @@ static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(threa static void clear_numa_thread_affinity(void) {} #endif -struct ggml_compute_state_shared { - const struct ggml_cgraph * cgraph; - const struct ggml_cplan * cplan; - int64_t perf_node_start_cycles; - int64_t perf_node_start_time_us; - - const int n_threads; - - // synchronization primitives - atomic_int n_active; // num active threads - atomic_int node_n; // active graph node - - bool (*abort_callback)(void * data); // abort ggml_graph_compute when true - void * abort_callback_data; -}; - -struct ggml_compute_state { - ggml_thread_t thrd; - int ith; - struct ggml_compute_state_shared * shared; -}; static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; @@ -17456,12 +17391,6 @@ static enum ggml_opt_result ggml_opt_adam( // https://github.com/chokkan/liblbfgs // -struct ggml_lbfgs_iteration_data { - float alpha; - float ys; - float * s; - float * y; -}; static enum ggml_opt_result linesearch_backtracking( const struct ggml_opt_params * params, @@ -18328,71 +18257,6 @@ static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); //}; static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); -union gguf_value { - uint8_t uint8; - int8_t int8; - uint16_t uint16; - int16_t int16; - uint32_t uint32; - 
int32_t int32; - float float32; - uint64_t uint64; - int64_t int64; - double float64; - bool bool_; - - struct gguf_str str; - - struct { - enum gguf_type type; - - uint64_t n; // GGUFv2 - void * data; - } arr; -}; - -struct gguf_kv { - struct gguf_str key; - - enum gguf_type type; - union gguf_value value; -}; - -struct gguf_header { - char magic[4]; - uint32_t version; - uint64_t n_tensors; // GGUFv2 - uint64_t n_kv; // GGUFv2 -}; - -struct gguf_tensor_info { - struct gguf_str name; - - uint32_t n_dims; - uint64_t ne[GGML_MAX_DIMS]; - - enum ggml_type type; - - uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` - - // for writing API - const void * data; - size_t size; -}; - -struct gguf_context { - struct gguf_header header; - - struct gguf_kv * kv; - struct gguf_tensor_info * infos; - - size_t alignment; - size_t offset; // offset of `data` from beginning of file - size_t size; // size of `data` in bytes - - //uint8_t * padding; - void * data; -}; static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { const size_t n = fread(dst, 1, size, file); @@ -19185,11 +19049,6 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo // fwrite(val, sizeof(char), size, file); //} -struct gguf_buf { - void * data; - size_t size; - size_t offset; -}; static struct gguf_buf gguf_buf_init(size_t size) { struct gguf_buf buf = { diff --git a/llama-internal.hpp b/llama-internal.hpp index 5d560372d..33cf39e5d 100644 --- a/llama-internal.hpp +++ b/llama-internal.hpp @@ -1,5 +1,5 @@ #include - +#include enum llm_arch { LLM_ARCH_LLAMA, LLM_ARCH_FALCON, @@ -90,7 +90,7 @@ enum llama_fver { struct LLM_KV { LLM_KV(llm_arch arch) : arch(arch) {} - + llm_arch arch; std::string operator()(llm_kv kv) const; // moved to llama.cpp file @@ -196,7 +196,7 @@ struct llama_buffer { // useful in cases where CUDA can try to allocate PINNED memory bool fallback = false; - void resize(size_t n) ; + void resize(size_t n) ; ~llama_buffer(); @@ -293,9 +293,9 @@ struct llama_vocab { struct llama_mmap { void * addr; size_t size; - + llama_mmap(const llama_mmap &) = delete; - + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false); ~llama_mmap(); @@ -371,8 +371,8 @@ struct llama_mlock { #undef MLOCK_SUGGESTION static void raw_unlock(void * addr, size_t size); #elif defined(_WIN32) - static constexpr bool SUPPORTED = true; - static size_t lock_granularity(); + static constexpr bool SUPPORTED = true; + static size_t lock_granularity(); bool raw_lock(void * ptr, size_t len) const ; static void raw_unlock(void * ptr, size_t len); #else @@ -516,3 +516,381 @@ struct LLM_TN { std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const ; }; + + +struct llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + llama_file(const char * fname, const char * mode) ; + size_t tell() const; + void seek(size_t offset, int whence) const; + void read_raw(void * ptr, size_t len) const; + uint32_t read_u32() const; + void write_raw(const void * ptr, size_t len) const ; + void write_u32(std::uint32_t val) const; + ~llama_file(); + +}; + + +struct llama_state { + llama_state(); + // We save the log callback globally + ggml_log_callback log_callback; + void * log_callback_user_data = nullptr; +}; + + + +struct llama_model_loader { + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + int64_t n_elements = 0; + size_t n_bytes = 0; + + bool 
use_mmap = false; + + llama_file file; + llama_ftype ftype; + llama_fver fver; + + std::unique_ptr mapping; + + struct gguf_context * ctx_gguf = NULL; + struct ggml_context * ctx_meta = NULL; + + llama_model_loader(const std::string & fname, bool use_mmap) ; + + ~llama_model_loader(); + + std::string get_arch_name() const; + + enum llm_arch get_arch() const ; + const char * get_tensor_name(int i) const; + + struct ggml_tensor * get_tensor_meta(int i) const; + + void calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const; + + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend_type backend) ; + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend_type backend) ; + + void done_getting_tensors() const; + + size_t file_offset(const char * name) const; + + + void load_data_for(struct ggml_tensor * cur) const ; + void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) ; +}; + +struct llama_data_context { + virtual void write(const void * src, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_context() = default; +}; + +struct llama_data_buffer_context : llama_data_context { + uint8_t * ptr; + size_t size_written = 0; + llama_data_buffer_context(uint8_t * p) ; + void write(const void * src, size_t size) override ; + size_t get_size_written() override ; +}; + +struct llama_data_file_context : llama_data_context { + llama_file * file; + size_t size_written = 0; + llama_data_file_context(llama_file * f); + size_t get_size_written() override ; + void write(const void * src, size_t size); +}; + + +struct llama_beam { + std::vector tokens; + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Initialize end-of-beam to false. Callback sets this to true. + // Sort beams by probability. In case of ties, prefer beams at eob. + bool operator<(const llama_beam & rhs) const ; + void shift_tokens(const size_t n) ; + llama_beam_view view() const; +}; + +// A struct for calculating logit-related info. 
+struct llama_logit_info { + const float * const logits; + const int n_vocab; + const float max_l; + const float normalizer; + struct sum_exp { + float max_l; + float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } + }; + llama_logit_info(llama_context * ctx); + llama_token_data get_token_data(const llama_token token_id) const ; + std::vector top_k(size_t k) ; + float probability_from_logit(float logit) const ; +}; + + +struct llama_beam_search_data { + llama_context * ctx; + size_t n_beams; + int n_past; + int n_predict; + std::vector beams; + std::vector next_beams; + size_t common_prefix_length; + std::vector beam_views; + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict); + void collapse_beams(const size_t beam_idx) ; + void fill_next_beams_by_top_probabilities(llama_beam & beam) ; + size_t find_common_prefix_length() ; + llama_beams_state get_beams_state(const bool last_call) ; + void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data); + static void renormalize_beam_probabilities(std::vector & beams) ; + size_t top_beam_index(); + void update_beams_from_beam_views(); +}; + +using llm_build_cb = std::function; + +enum llm_rope_type { + LLM_ROPE, + LLM_ROPE_NEOX, + LLM_ROPE_GLM, +}; + +enum llm_ffn_op_type { + LLM_FFN_SILU, + LLM_FFN_GELU, + LLM_FFN_RELU, + LLM_FFN_RELU_SQR, +}; + +enum llm_ffn_gate_type { + LLM_FFN_SEQ, + LLM_FFN_PAR, // ffn_gate is parallel to ffn_up +}; + +enum llm_norm_type { + LLM_NORM, + LLM_NORM_RMS, +}; + +struct llm_build_context { + const llama_model & model; + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_batch & batch; + const llama_kv_cache & kv_self; + + const int64_t n_embd; + const int64_t n_layer; + const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) + const int64_t n_head; + const int64_t n_head_kv; + const int64_t n_embd_head; + const int64_t n_embd_gqa; + + const float freq_base; + const float freq_scale; + const float ext_factor; + const float attn_factor; + const float beta_fast; + const float beta_slow; + const float norm_eps; + const float norm_rms_eps; + + const int32_t n_tokens; + const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx) + const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t n_orig_ctx; + + const bool do_rope_shift; + + const llm_build_cb & cb; + + llama_buffer & buf_compute; + + struct ggml_context * ctx0 = nullptr; + + // TODO: consider making the entire interface noexcept + llm_build_context( + llama_context & lctx, + const llama_batch & batch, + const llm_build_cb & cb, + bool worst_case); + + void init() ; + void free() ; + struct ggml_cgraph * build_llama() ; + struct ggml_cgraph * build_baichuan() ; + struct ggml_cgraph * build_falcon() ; + struct ggml_cgraph * build_starcoder() ; + struct ggml_cgraph * build_persimmon() ; + struct ggml_cgraph * build_refact() ; + struct ggml_cgraph * build_bloom() ; + struct ggml_cgraph * build_mpt() ; + struct ggml_cgraph * build_stablelm(); +}; + + +enum llm_offload_func_e { + OFFLOAD_FUNC_NOP, + OFFLOAD_FUNC, + OFFLOAD_FUNC_KQ, + OFFLOAD_FUNC_V, + OFFLOAD_FUNC_NR, + OFFLOAD_FUNC_EMB, + OFFLOAD_FUNC_OUT, +}; + +struct llm_offload_trie { + struct node { + ~node() ; + node * children[256] = { nullptr }; + llm_offload_func_e func = OFFLOAD_FUNC_NOP; + }; + node * root = nullptr; + llm_offload_trie(); + llm_offload_trie(const std::unordered_map & map) ; + ~llm_offload_trie(); + void 
add(const char * name, llm_offload_func_e func); + llm_offload_func_e find(const char * name) const; + +}; + +struct llm_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; +}; + + +struct llm_bigram_spm { + struct comparator { + bool operator()(llm_bigram_spm & l, llm_bigram_spm & r); + }; + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + float score; + size_t size; +}; + +struct llm_tokenizer_spm { + llm_tokenizer_spm(const llama_vocab & vocab); + void tokenize(const std::string & text, std::vector & output); + + +private: + void resegment(llm_symbol & symbol, std::vector & output) ; + void try_add_bigram(int left, int right) ; + const llama_vocab & vocab; + + std::vector symbols; + llm_bigram_spm::queue work_queue; + + std::map> rev_merge; +}; + +// BPE tokenizer +// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] +// tried to simplify unicode stuff, so most likely does not work 100% correctly! + +// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused + +struct llm_bigram_bpe { + struct comparator { + bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const ; + }; + + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + std::string text; + int rank; + size_t size; +}; + +struct llm_tokenizer_bpe { + llm_tokenizer_bpe(const llama_vocab & vocab); + + void tokenize(const std::string & text, std::vector & output); + +private: + void add_new_bigram(int left, int right) ; + + std::vector bpe_gpt2_preprocess(const std::string & text) ; + + const llama_vocab & vocab; + + std::vector symbols; + std::vector symbols_final; + + llm_bigram_bpe::queue work_queue; +}; + +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ + FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, + FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT +} FRAGMENT_BUFFER_VARIANT_TYPE; + +struct fragment_buffer_variant{ + fragment_buffer_variant(llama_vocab::id _token); + fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length); + const FRAGMENT_BUFFER_VARIANT_TYPE type; + const llama_vocab::id token; + const std::string _dummy; + const std::string & raw_text; + const uint64_t offset; + const uint64_t length; +}; + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar { + const std::vector> rules; + std::vector> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +struct quantize_state_internal { + const llama_model & model; + const llama_model_quantize_params * params; + + int n_attention_wv = 0; + int n_feed_forward_w2 = 0; + int i_attention_wv = 0; + int i_feed_forward_w2 = 0; + + int n_k_quantized = 0; + int n_fallback = 0; + + quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params) + : model(model) + , params(params) + {} +}; diff --git a/llama.cpp b/llama.cpp index 675d147c8..5682234e7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -639,12 +639,8 @@ llama_buffer::~llama_buffer() { } -struct llama_file { - // use FILE * so we don't have to re-open the file to mmap - FILE * fp; - size_t size; - llama_file(const char 
* fname, const char * mode) { +llama_file::llama_file(const char * fname, const char * mode) { fp = std::fopen(fname, mode); if (fp == NULL) { throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); @@ -654,7 +650,7 @@ struct llama_file { seek(0, SEEK_SET); } - size_t tell() const { +size_t llama_file::tell() const { #ifdef _WIN32 __int64 ret = _ftelli64(fp); #else @@ -664,7 +660,8 @@ struct llama_file { return (size_t) ret; } - void seek(size_t offset, int whence) const { +void llama_file::seek(size_t offset, int whence) const { + #ifdef _WIN32 int ret = _fseeki64(fp, (__int64) offset, whence); #else @@ -673,7 +670,7 @@ struct llama_file { GGML_ASSERT(ret == 0); // same } - void read_raw(void * ptr, size_t len) const { +void llama_file::read_raw(void * ptr, size_t len) const { if (len == 0) { return; } @@ -687,13 +684,13 @@ struct llama_file { } } - uint32_t read_u32() const { +uint32_t llama_file::read_u32() const { uint32_t ret; read_raw(&ret, sizeof(ret)); return ret; } - void write_raw(const void * ptr, size_t len) const { +void llama_file::write_raw(const void * ptr, size_t len) const { if (len == 0) { return; } @@ -704,16 +701,16 @@ struct llama_file { } } - void write_u32(std::uint32_t val) const { +void llama_file::write_u32(std::uint32_t val) const { write_raw(&val, sizeof(val)); } - ~llama_file() { +llama_file::~llama_file() { if (fp) { std::fclose(fp); } } -}; + // @@ -985,12 +982,6 @@ static std::string llama_token_to_piece(const struct llama_context * ctx, llama_ // globals // -struct llama_state { - // We save the log callback globally - ggml_log_callback log_callback = llama_log_callback_default; - void * log_callback_user_data = nullptr; -}; - static llama_state g_state; @@ -1276,26 +1267,8 @@ static std::string llama_format_tensor_shape(const struct ggml_tensor * t) { return buf; } -struct llama_model_loader { - int n_kv = 0; - int n_tensors = 0; - int n_created = 0; - int64_t n_elements = 0; - size_t n_bytes = 0; - - bool use_mmap = false; - - llama_file file; - llama_ftype ftype; - llama_fver fver; - - std::unique_ptr mapping; - - struct gguf_context * ctx_gguf = NULL; - struct ggml_context * ctx_meta = NULL; - - llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") { +llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") { struct gguf_init_params params( /*.no_alloc =*/ true, /*.ctx = */ &ctx_meta @@ -1409,7 +1382,7 @@ struct llama_model_loader { this->use_mmap = use_mmap; } - ~llama_model_loader() { + llama_model_loader::~llama_model_loader() { if (ctx_gguf) { gguf_free(ctx_gguf); } @@ -1418,7 +1391,7 @@ struct llama_model_loader { } } - std::string get_arch_name() const { + std::string llama_model_loader::get_arch_name() const { const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); std::string arch_name; @@ -1427,21 +1400,21 @@ struct llama_model_loader { return arch_name; } - enum llm_arch get_arch() const { + enum llm_arch llama_model_loader::get_arch() const { const std::string arch_name = get_arch_name(); return llm_arch_from_string(arch_name); } - const char * get_tensor_name(int i) const { + const char * llama_model_loader::get_tensor_name(int i) const { return gguf_get_tensor_name(ctx_gguf, i); } - struct ggml_tensor * get_tensor_meta(int i) const { + struct ggml_tensor * llama_model_loader::get_tensor_meta(int i) const { return ggml_get_tensor(ctx_meta, get_tensor_name(i)); } - void calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const { + void 
llama_model_loader::calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const { ctx_size_p = 0; mmapped_size_p = 0; @@ -1452,7 +1425,7 @@ struct llama_model_loader { } } - struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend_type backend) { + struct ggml_tensor * llama_model_loader::create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend_type backend) { if (backend != GGML_BACKEND_CPU) { ggml_set_no_alloc(ctx, true); } @@ -1470,7 +1443,7 @@ struct llama_model_loader { return tensor; } - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend_type backend) { + struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend_type backend) { struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); if (cur == NULL) { @@ -1503,13 +1476,13 @@ struct llama_model_loader { return create_tensor_for(ctx, cur, backend); } - void done_getting_tensors() const { + void llama_model_loader::done_getting_tensors() const { if (n_created != n_tensors) { throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); } } - size_t file_offset(const char * name) const { + size_t llama_model_loader::file_offset(const char * name) const { const int idx = gguf_find_tensor(ctx_gguf, name); if (idx < 0) { @@ -1519,7 +1492,7 @@ struct llama_model_loader { return gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, idx); } - void load_data_for(struct ggml_tensor * cur) const { + void llama_model_loader::load_data_for(struct ggml_tensor * cur) const { const size_t offs = file_offset(ggml_get_name(cur)); if (use_mmap) { @@ -1530,7 +1503,7 @@ struct llama_model_loader { } } - void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { + void llama_model_loader::load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { size_t size_data = 0; size_t size_lock = 0; size_t size_pref = 0; // prefetch @@ -1606,7 +1579,7 @@ struct llama_model_loader { done_size += ggml_nbytes(cur); } } -}; + //}; // // load LLaMA models @@ -2940,30 +2913,6 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con // llm_build // -using llm_build_cb = std::function; - -enum llm_rope_type { - LLM_ROPE, - LLM_ROPE_NEOX, - LLM_ROPE_GLM, -}; - -enum llm_ffn_op_type { - LLM_FFN_SILU, - LLM_FFN_GELU, - LLM_FFN_RELU, - LLM_FFN_RELU_SQR, -}; - -enum llm_ffn_gate_type { - LLM_FFN_SEQ, - LLM_FFN_PAR, // ffn_gate is parallel to ffn_up -}; - -enum llm_norm_type { - LLM_NORM, - LLM_NORM_RMS, -}; static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, @@ -3278,45 +3227,10 @@ static struct ggml_tensor * llm_build_kqv( return cur; } -struct llm_build_context { - const llama_model & model; - const llama_hparams & hparams; - const llama_cparams & cparams; - const llama_batch & batch; - const llama_kv_cache & kv_self; - - const int64_t n_embd; - const int64_t n_layer; - const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) - const int64_t n_head; - const int64_t n_head_kv; - const int64_t n_embd_head; - const int64_t n_embd_gqa; - - const float freq_base; - const float freq_scale; - const float ext_factor; - 
const float attn_factor; - const float beta_fast; - const float beta_slow; - const float norm_eps; - const float norm_rms_eps; - - const int32_t n_tokens; - const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx) - const int32_t kv_head; // index of where we store new KV data in the cache - const int32_t n_orig_ctx; - - const bool do_rope_shift; - - const llm_build_cb & cb; - - llama_buffer & buf_compute; - - struct ggml_context * ctx0 = nullptr; +// struct llm_build_context { // TODO: consider making the entire interface noexcept - llm_build_context( +llm_build_context::llm_build_context( llama_context & lctx, const llama_batch & batch, const llm_build_cb & cb, @@ -3353,7 +3267,7 @@ struct llm_build_context { // all initializations should be done in init() } - void init() { +void llm_build_context::init() { struct ggml_init_params params( //.mem_size = buf_compute.size, @@ -3366,14 +3280,14 @@ struct llm_build_context { ctx0 = ggml_init(params); } - void free() { + void llm_build_context::free() { if (ctx0) { ggml_free(ctx0); ctx0 = nullptr; } } - struct ggml_cgraph * build_llama() { + struct ggml_cgraph * llm_build_context::build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -3485,7 +3399,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_baichuan() { +struct ggml_cgraph * llm_build_context::build_baichuan() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_tensor * cur; @@ -3605,7 +3519,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_falcon() { +struct ggml_cgraph * llm_build_context::build_falcon() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_tensor * cur; @@ -3727,7 +3641,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_starcoder() { +struct ggml_cgraph * llm_build_context::build_starcoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_tensor * cur; @@ -3826,7 +3740,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_persimmon() { + struct ggml_cgraph * llm_build_context::build_persimmon() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); const int64_t n_rot = n_embd_head / 2; @@ -4036,7 +3950,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_refact() { +struct ggml_cgraph * llm_build_context::build_refact() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_tensor * cur; @@ -4127,7 +4041,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_bloom() { +struct ggml_cgraph * llm_build_context::build_bloom() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_tensor * cur; @@ -4221,7 +4135,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_mpt() { +struct ggml_cgraph * llm_build_context::build_mpt() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_tensor * cur; @@ -4320,7 +4234,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_stablelm() { +struct ggml_cgraph * llm_build_context::build_stablelm() { struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * cur; @@ -4432,27 +4346,18 @@ struct llm_build_context { return gf; } -}; + // // tensor offloading helpers // // TODO: will be removed with backend v2 
-enum llm_offload_func_e { - OFFLOAD_FUNC_NOP, - OFFLOAD_FUNC, - OFFLOAD_FUNC_KQ, - OFFLOAD_FUNC_V, - OFFLOAD_FUNC_NR, - OFFLOAD_FUNC_EMB, - OFFLOAD_FUNC_OUT, -}; // TODO: will be removed with backend v2 -struct llm_offload_trie { - struct node { - ~node() { +//struct llm_offload_trie { +// struct node { +llm_offload_trie::node::~node() { for (int i = 0; i < 256; ++i) { if (children[i]) { delete children[i]; @@ -4460,28 +4365,28 @@ struct llm_offload_trie { } } - node * children[256] = { nullptr }; - llm_offload_func_e func = OFFLOAD_FUNC_NOP; - }; +// node * children[256] = { nullptr }; +// llm_offload_func_e func = OFFLOAD_FUNC_NOP; +// }; - llm_offload_trie() { +llm_offload_trie::llm_offload_trie() { root = new node; } - llm_offload_trie(const std::unordered_map & map) { - root = new node; +llm_offload_trie::llm_offload_trie(const std::unordered_map & map) { + root = new node; + + for (const auto & kv : map) { + add(kv.first, kv.second); + } +} - for (const auto & kv : map) { - add(kv.first, kv.second); - } - } +llm_offload_trie::~llm_offload_trie() { + delete root; +} - ~llm_offload_trie() { - delete root; - } - - void add(const char * name, llm_offload_func_e func) { - node * cur = root; +void llm_offload_trie::add(const char * name, llm_offload_func_e func) { + node * cur = root; for (int i = 0; ; ++i) { const uint8_t c = name[i]; @@ -4500,7 +4405,7 @@ struct llm_offload_trie { cur->func = func; } - llm_offload_func_e find(const char * name) const { +llm_offload_func_e llm_offload_trie::find(const char * name) const { const node * cur = root; for (int i = 0; ; ++i) { @@ -4520,8 +4425,8 @@ struct llm_offload_trie { return cur->func; } - node * root = nullptr; -}; +// node * root = nullptr; +//}; // TODO: will be removed with backend v2 static const std::unordered_map k_offload_map = { @@ -5255,13 +5160,6 @@ static void llama_unescape_whitespace(std::string & word) { replace_all(word, "\xe2\x96\x81", " "); } -struct llm_symbol { - using index = int; - index prev; - index next; - const char * text; - size_t n; -}; static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); @@ -5269,24 +5167,16 @@ static_assert(std::is_trivially_copyable::value, "llm_symbol is not // original implementation: // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 -struct llm_bigram_spm { - struct comparator { - bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) { - return (l.score < r.score) || (l.score == r.score && l.left > r.left); - } - }; - using queue_storage = std::vector; - using queue = std::priority_queue; - llm_symbol::index left; - llm_symbol::index right; - float score; - size_t size; -}; -struct llm_tokenizer_spm { - llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {} +bool llm_bigram_spm::comparator::operator()(llm_bigram_spm & l, llm_bigram_spm & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); +} - void tokenize(const std::string & text, std::vector & output) { + +// struct llm_tokenizer_spm { +llm_tokenizer_spm::llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {} + +void llm_tokenizer_spm::tokenize(const std::string & text, std::vector & output) { // split string into utf8 chars int index = 0; size_t offs = 0; @@ -5344,8 +5234,8 @@ struct llm_tokenizer_spm { } } -private: - void resegment(llm_symbol & symbol, std::vector & output) { +//private: +void llm_tokenizer_spm::resegment(llm_symbol & symbol, std::vector & output) { auto text = std::string(symbol.text, symbol.n); 
auto token = vocab.token_to_id.find(text); @@ -5370,7 +5260,7 @@ private: resegment(symbols[p->second.second], output); } - void try_add_bigram(int left, int right) { +void llm_tokenizer_spm::try_add_bigram(int left, int right) { if (left == -1 || right == -1) { return; } @@ -5400,13 +5290,6 @@ private: rev_merge[text] = std::make_pair(left, right); } - const llama_vocab & vocab; - - std::vector symbols; - llm_bigram_spm::queue work_queue; - - std::map> rev_merge; -}; // BPE tokenizer // adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] @@ -5414,26 +5297,15 @@ private: // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused -struct llm_bigram_bpe { - struct comparator { - bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { - return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); - } - }; - using queue_storage = std::vector; - using queue = std::priority_queue; - llm_symbol::index left; - llm_symbol::index right; - std::string text; - int rank; - size_t size; -}; +bool llm_bigram_bpe::comparator::operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { + return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); +} -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {} +//struct llm_tokenizer_bpe { +llm_tokenizer_bpe::llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {} - void tokenize(const std::string & text, std::vector & output) { + void llm_tokenizer_bpe::tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; auto word_collection = bpe_gpt2_preprocess(text); @@ -5534,8 +5406,8 @@ struct llm_tokenizer_bpe { } } -private: - void add_new_bigram(int left, int right) { +//private: +void llm_tokenizer_bpe::add_new_bigram(int left, int right) { if (left == -1 || right == -1) { return; } @@ -5562,7 +5434,7 @@ private: work_queue.push(bigram); } - std::vector bpe_gpt2_preprocess(const std::string & text) { + std::vector llm_tokenizer_bpe::bpe_gpt2_preprocess(const std::string & text) { std::vector bpe_words; std::vector bpe_encoded_words; @@ -5701,28 +5573,17 @@ private: return bpe_encoded_words; } - const llama_vocab & vocab; - std::vector symbols; - std::vector symbols_final; - llm_bigram_bpe::queue work_queue; -}; - -typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ - FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, - FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT -} FRAGMENT_BUFFER_VARIANT_TYPE; - -struct fragment_buffer_variant{ - fragment_buffer_variant(llama_vocab::id _token) +//struct fragment_buffer_variant{ +fragment_buffer_variant::fragment_buffer_variant(llama_vocab::id _token) : type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), token(_token), raw_text(_dummy), offset(0), length(0){} - fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) +fragment_buffer_variant::fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) : type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), token((llama_vocab::id)-1), @@ -5734,13 +5595,6 @@ struct fragment_buffer_variant{ GGML_ASSERT( offset + length <= raw_text.length() ); } - const FRAGMENT_BUFFER_VARIANT_TYPE type; - const llama_vocab::id token; - const std::string _dummy; - const std::string & raw_text; - const uint64_t offset; - const uint64_t length; -}; // #define PRETOKENIZERDEBUG @@ -5946,24 +5800,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & // grammar - internal // -struct llama_partial_utf8 { - uint32_t 
value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar { - const std::vector> rules; - std::vector> stacks; - - // buffer for partially generated UTF-8 sequence from accepted tokens - llama_partial_utf8 partial_utf8; -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. @@ -6895,22 +6731,19 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar // Beam search // -struct llama_beam { - std::vector tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Initialize end-of-beam to false. Callback sets this to true. - // Sort beams by probability. In case of ties, prefer beams at eob. - bool operator<(const llama_beam & rhs) const { +// llama_beam { + +bool llama_beam::operator<(const llama_beam & rhs) const { return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob); } // Shift off first n tokens and discard them. - void shift_tokens(const size_t n) { +void llama_beam::shift_tokens(const size_t n) { if (n) { std::copy(tokens.begin() + n, tokens.end(), tokens.begin()); tokens.resize(tokens.size() - n); } } - llama_beam_view view() const { +llama_beam_view llama_beam::view() const { llama_beam_view bv = { .tokens =tokens.data(), .n_tokens= tokens.size(), @@ -6919,25 +6752,25 @@ struct llama_beam { }; return bv; } -}; + // A struct for calculating logit-related info. -struct llama_logit_info { - const float * const logits; - const int n_vocab; - const float max_l; - const float normalizer; - struct sum_exp { - float max_l; - float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } - }; - llama_logit_info(llama_context * ctx) +//struct llama_logit_info { +// const float * const logits; +// const int n_vocab; +// const float max_l; +// const float normalizer; +// struct sum_exp { +// float max_l; +// float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } +// }; +llama_logit_info::llama_logit_info(llama_context * ctx) : logits(llama_get_logits(ctx)) , n_vocab(llama_n_vocab(llama_get_model(ctx))) , max_l(*std::max_element(logits, logits + n_vocab)) , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) { } - llama_token_data get_token_data(const llama_token token_id) const { +llama_token_data llama_logit_info::get_token_data(const llama_token token_id) const { constexpr auto p = std::numeric_limits::quiet_NaN(); // never used llama_token_data dd( token_id, @@ -6947,7 +6780,7 @@ struct llama_logit_info { return dd; } // Return top k token_data by logit. 
- std::vector top_k(size_t k) { +std::vector llama_logit_info::top_k(size_t k) { std::vector min_heap; // min-heap by logit const llama_token k_min = std::min(static_cast(k), n_vocab); min_heap.reserve(k_min); @@ -6966,26 +6799,15 @@ struct llama_logit_info { } return min_heap; } - float probability_from_logit(float logit) const { +float llama_logit_info::probability_from_logit(float logit) const { return normalizer * std::exp(logit - max_l); } -}; -struct llama_beam_search_data { - llama_context * ctx; - size_t n_beams; - int n_past; - int n_predict; - std::vector beams; - std::vector next_beams; - // Re-calculated on each loop iteration - size_t common_prefix_length; +//struct llama_beam_search_data { - // Used to communicate to/from callback on beams state. - std::vector beam_views; - llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) +llama_beam_search_data::llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) : ctx(ctx) , n_beams(n_beams) , n_past(n_past) @@ -6996,7 +6818,7 @@ struct llama_beam_search_data { } // Collapse beams to a single beam given by index. - void collapse_beams(const size_t beam_idx) { +void llama_beam_search_data::collapse_beams(const size_t beam_idx) { if (0u < beam_idx) { std::swap(beams[0], beams[beam_idx]); } @@ -7008,7 +6830,7 @@ struct llama_beam_search_data { // * Gather elements until the vector is full, then call std::make_heap() on it. // * If the heap is full and a new element is found that should be included, pop the // least element to the back(), replace it with the new, then push it into the heap. - void fill_next_beams_by_top_probabilities(llama_beam & beam) { +void llama_beam_search_data::fill_next_beams_by_top_probabilities(llama_beam & beam) { // Min-heaps use a greater-than comparator. const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; }; if (beam.eob) { @@ -7063,7 +6885,7 @@ struct llama_beam_search_data { // Find common_prefix_length based on beams. // Requires beams is not empty. - size_t find_common_prefix_length() { +size_t llama_beam_search_data::find_common_prefix_length() { size_t common_prefix_length = beams[0].tokens.size(); for (size_t i = 1 ; i < beams.size() ; ++i) { common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size()); @@ -7079,7 +6901,7 @@ struct llama_beam_search_data { // Construct beams_state to send back to caller via the callback function. // Side effect: set common_prefix_length = find_common_prefix_length(); - llama_beams_state get_beams_state(const bool last_call) { +llama_beams_state llama_beam_search_data::get_beams_state(const bool last_call) { for (size_t i = 0 ; i < beams.size() ; ++i) { beam_views[i] = beams[i].view(); } @@ -7098,7 +6920,7 @@ struct llama_beam_search_data { // * any of the beams have not yet reached end-of-beam (eob), AND // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence // (since all other beam probabilities can only decrease) - void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { +void llama_beam_search_data::loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob. 
const auto not_eob = [](const llama_beam & beam) { return !beam.eob; }; for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) && @@ -7125,25 +6947,25 @@ struct llama_beam_search_data { // As beams grow, the cumulative probabilities decrease. // Renormalize them to avoid floating point underflow. - static void renormalize_beam_probabilities(std::vector & beams) { +void llama_beam_search_data::renormalize_beam_probabilities(std::vector & beams) { const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; }; const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p); std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; }); } // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering. - size_t top_beam_index() { +size_t llama_beam_search_data::top_beam_index() { return std::max_element(beams.begin(), beams.end()) - beams.begin(); } // Copy (p,eob) for each beam which may have been changed by the callback. - void update_beams_from_beam_views() { +void llama_beam_search_data::update_beams_from_beam_views() { for (size_t i = 0 ; i < beams.size() ; ++i) { beams[i].p = beam_views[i].p; beams[i].eob = beam_views[i].eob; } } -}; + void llama_beam_search(llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, @@ -7169,23 +6991,6 @@ struct no_init { no_init() { /* do nothing */ } }; -struct quantize_state_internal { - const llama_model & model; - const llama_model_quantize_params * params; - - int n_attention_wv = 0; - int n_feed_forward_w2 = 0; - int i_attention_wv = 0; - int i_feed_forward_w2 = 0; - - int n_k_quantized = 0; - int n_fallback = 0; - - quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params) - : model(model) - , params(params) - {} -}; static void llama_convert_tensor_internal( struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, @@ -8442,45 +8247,32 @@ size_t llama_get_state_size(const struct llama_context * ctx) { return s_total; } -// llama_context_data -struct llama_data_context { - virtual void write(const void * src, size_t size) = 0; - virtual size_t get_size_written() = 0; - virtual ~llama_data_context() = default; -}; -struct llama_data_buffer_context : llama_data_context { - uint8_t * ptr; - size_t size_written = 0; - llama_data_buffer_context(uint8_t * p) : ptr(p) {} + llama_data_buffer_context::llama_data_buffer_context(uint8_t * p) : ptr(p) {} - void write(const void * src, size_t size) override { - memcpy(ptr, src, size); - ptr += size; - size_written += size; - } +void llama_data_buffer_context::write(const void * src, size_t size) { + memcpy(ptr, src, size); + ptr += size; + size_written += size; +} - size_t get_size_written() override { - return size_written; - } -}; +size_t llama_data_buffer_context::get_size_written() { + return size_written; +} -struct llama_data_file_context : llama_data_context { - llama_file * file; - size_t size_written = 0; - llama_data_file_context(llama_file * f) : file(f) {} + +llama_data_file_context::llama_data_file_context(llama_file * f) : file(f) {} - void write(const void * src, size_t size) override { - file->write_raw(src, size); - size_written += size; - } +void llama_data_file_context::write(const void * src, size_t size) { + file->write_raw(src, size); + size_written += size; +} - size_t get_size_written() override { - return size_written; - } -}; +size_t llama_data_file_context::get_size_written() { + return size_written; +} /** 
copy state data into either a buffer or file depending on the passed in context * @@ -9287,3 +9079,6 @@ llama_context::~llama_context() { ggml_allocr_free(alloc); } } +llama_state::llama_state(){ + log_callback= llama_log_callback_default; +} diff --git a/llama.h b/llama.h index cbde7990b..fa430896c 100644 --- a/llama.h +++ b/llama.h @@ -114,7 +114,7 @@ LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN, }; - typedef struct llama_token_data : refl::attr::usage::type{ + typedef struct llama_token_data { llama_token_data( llama_token id, float logit, float p): id( id),logit(logit),p(p){ } llama_token id; // token id @@ -122,7 +122,7 @@ float p; // probability of the token } llama_token_data; - typedef struct llama_token_data_array : refl::attr::usage::type{ + typedef struct llama_token_data_array { llama_token_data_array(llama_token_data * data, size_t size, bool sorted): @@ -146,7 +146,7 @@ // - seq_id : the sequence to which the respective token belongs // - logits : if zero, the logits for the respective token will not be output // - typedef struct llama_batch : refl::attr::usage::type{ + typedef struct llama_batch { llama_batch(int32_t n_tokens, llama_token * token, @@ -205,7 +205,7 @@ bool use_mlock; // force system to keep model in RAM }; - struct llama_context_params : refl::attr::usage::type{ + struct llama_context_params{ uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // prompt processing maximum batch size @@ -230,7 +230,7 @@ }; // model quantization parameters - typedef struct llama_model_quantize_params : refl::attr::usage::type{ + typedef struct llama_model_quantize_params { int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() enum llama_ftype ftype; // quantize to this llama_ftype bool allow_requantize; // allow quantizing non-f32/f16 tensors @@ -268,7 +268,7 @@ LLAMA_GRETYPE_CHAR_ALT = 6, }; - typedef struct llama_grammar_element : refl::attr::usage::type { + typedef struct llama_grammar_element { llama_grammar_element( enum llama_gretype type, uint32_t value // Unicode code point or rule ID ):type(type), value(value){} @@ -278,7 +278,7 @@ } llama_grammar_element; // performance timing information - struct llama_timings : refl::attr::usage::type{ + struct llama_timings { double t_start_ms; double t_end_ms; double t_load_ms; @@ -755,7 +755,7 @@ // Beam search // - struct llama_beam_view : refl::attr::usage::type{ + struct llama_beam_view { const llama_token * tokens; size_t n_tokens; @@ -767,7 +767,7 @@ // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. // These pointers are valid only during the synchronous callback, so should not be saved. - struct llama_beams_state : refl::attr::usage::type{ + struct llama_beams_state { struct llama_beam_view * beam_views; size_t n_beams; // Number of elements in beam_views[]. 
@@ -831,3 +831,5 @@ const std::vector> & llama_internal #endif // LLAMA_H + + diff --git a/print.hpp b/print.hpp index 551683a19..61be516c9 100644 --- a/print.hpp +++ b/print.hpp @@ -1,7 +1,4 @@ -//template void print_fields(const T& obj); - #include -//#include #include "llama.h" #include "ggml-internal.hpp" #include "llama-internal.hpp" @@ -56,9 +53,9 @@ REFL_FIELD(prompt_file ) REFL_FIELD(path_prompt_cache ) REFL_FIELD(input_prefix ) REFL_FIELD(input_suffix ) -//REFL_FIELD( antiprompt) +REFL_FIELD( antiprompt) REFL_FIELD(logdir ) -//REFL_FIELD( lora_adapter) +REFL_FIELD( lora_adapter) REFL_FIELD(lora_base ) REFL_FIELD( ppl_stride ) REFL_FIELD( ppl_output_type ) @@ -95,9 +92,6 @@ REFL_END REFL_TYPE(llama_sampling_params) REFL_END -REFL_TYPE(llama_buffer) -REFL_END - REFL_TYPE(llm_arch) REFL_END @@ -106,8 +100,8 @@ REFL_FIELD( params) REFL_FIELD( mirostat_mu) REFL_FIELD( grammar) REFL_FIELD( parsed_grammar) -//REFL_FIELD( prev) vector of ints -//REFL_FIELD( cur) +REFL_FIELD( prev) +REFL_FIELD( cur) REFL_END REFL_TYPE(llama_token_data ) @@ -183,87 +177,82 @@ REFL_TYPE(ggml_context_container) REFL_FIELD(context) REFL_END -// REFL_TYPE(ggml_numa_node) -// REFL_FIELD(cpus) -// REFL_FIELD(n_cpus) -// REFL_END + REFL_TYPE(ggml_numa_node) + REFL_FIELD(cpus) + REFL_FIELD(n_cpus) + REFL_END -// REFL_TYPE(ggml_numa_nodes) -// REFL_FIELD(nodes) -// REFL_FIELD(n_nodes) -// REFL_END + REFL_TYPE(ggml_numa_nodes) + REFL_FIELD(nodes) + REFL_FIELD(n_nodes) + REFL_END -// REFL_TYPE(ggml_state) -// REFL_FIELD(contexts) -// REFL_FIELD(numa) -// REFL_END + REFL_TYPE(ggml_state) + REFL_FIELD(contexts) + REFL_FIELD(numa) + REFL_END -// REFL_TYPE(gguf_str) -// REFL_FIELD(n) -// REFL_FIELD(data) -// REFL_END + REFL_TYPE(gguf_str) + REFL_FIELD(n) + REFL_FIELD(data) + REFL_END -// REFL_TYPE(ggml_map_custom1_op_params) -// REFL_FIELD(fun) -// REFL_FIELD(n_tasks) -// REFL_END + REFL_TYPE(ggml_map_custom1_op_params) + REFL_FIELD(fun) + REFL_FIELD(n_tasks) + REFL_END -// REFL_TYPE(ggml_map_custom2_op_params) -// REFL_FIELD(fun) -// REFL_FIELD(n_tasks) -// REFL_END - -// REFL_TYPE(ggml_map_custom3_op_params) -// REFL_FIELD(fun) -// REFL_FIELD(n_tasks) -// REFL_END - -// REFL_TYPE(hash_map) -// REFL_FIELD(set) -// REFL_FIELD(vals) -// REFL_END -// REFL_TYPE(ggml_compute_state_shared) -// REFL_FIELD(cgraph) -// REFL_FIELD(cplan) -// REFL_END -// REFL_TYPE(ggml_compute_state) -// REFL_FIELD(thrd) -// REFL_FIELD(ith) -// REFL_END -// REFL_TYPE(ggml_lbfgs_iteration_data) -// REFL_FIELD(alpha) -// REFL_FIELD(ys) -// REFL_END -//REFL_TYPE() -// REFL_FIELD(type) -//REFL_END -// REFL_TYPE(gguf_kv) -// REFL_FIELD(key) -// REFL_FIELD(type) -// REFL_END - -// REFL_TYPE(gguf_header) -// REFL_FIELD(magic) -// REFL_FIELD(version) -// REFL_END - -// REFL_TYPE(gguf_tensor_info) -// REFL_FIELD(name) -// REFL_FIELD(n_dims) -// REFL_END - -REFL_TYPE(gguf_context) -// REFL_FIELD(header) -// REFL_FIELD(kv) +REFL_TYPE(ggml_map_custom2_op_params) + REFL_FIELD(fun) + REFL_FIELD(n_tasks) REFL_END -// REFL_TYPE(gguf_buf) -// REFL_FIELD(data) -// REFL_FIELD(size) -// REFL_END +REFL_TYPE(ggml_map_custom3_op_params) + REFL_FIELD(fun) + REFL_FIELD(n_tasks) +REFL_END -//REFL_TYPE(llama_token_data) -//REFL_END +REFL_TYPE(hash_map) + REFL_FIELD(set) + REFL_FIELD(vals) +REFL_END +REFL_TYPE(ggml_compute_state_shared) + REFL_FIELD(cgraph) + REFL_FIELD(cplan) +REFL_END +REFL_TYPE(ggml_compute_state) + REFL_FIELD(thrd) + REFL_FIELD(ith) +REFL_END +REFL_TYPE(ggml_lbfgs_iteration_data) + REFL_FIELD(alpha) + REFL_FIELD(ys) +REFL_END + +REFL_TYPE(gguf_kv) + 
REFL_FIELD(key) + REFL_FIELD(type) +REFL_END + +REFL_TYPE(gguf_header) + REFL_FIELD(magic) + REFL_FIELD(version) +REFL_END + +REFL_TYPE(gguf_tensor_info) + REFL_FIELD(name) + REFL_FIELD(n_dims) +REFL_END + +REFL_TYPE(gguf_context) + REFL_FIELD(header) + REFL_FIELD(kv) +REFL_END + +REFL_TYPE(gguf_buf) + REFL_FIELD(data) + REFL_FIELD(size) +REFL_END REFL_TYPE(llama_model_params) @@ -290,55 +279,55 @@ REFL_TYPE(llama_beams_state) REFL_FIELD(beam_views) REFL_END -//REFL_TYPE(ggml_backend) -//REFL_END +REFL_TYPE(ggml_backend) +REFL_END REFL_TYPE(ggml_backend_buffer) REFL_END -//REFL_TYPE(ggml_allocr) -//REFL_END +REFL_TYPE(ggml_allocr) +REFL_END -//REFL_TYPE(ggml_tallocr) -//REFL_END +REFL_TYPE(ggml_tallocr) +REFL_END -//REFL_TYPE(ggml_gallocr) -//REFL_END +REFL_TYPE(ggml_gallocr) +REFL_END -//REFL_TYPE(llama_buffer) -//REFL_FIELD(data) -//REFL_FIELD(size) -//REFL_END +REFL_TYPE(llama_buffer) +REFL_FIELD(data) +REFL_FIELD(size) +REFL_END -// REFL_TYPE(llama_file) -// REFL_FIELD(fp) -// REFL_FIELD(size) -// REFL_END +REFL_TYPE(llama_file) +REFL_FIELD(fp) +REFL_FIELD(size) +REFL_END -// REFL_TYPE(llama_mmap) -// REFL_FIELD(addr) -// REFL_FIELD(size) -// REFL_END +REFL_TYPE(llama_mmap) +REFL_FIELD(addr) +REFL_FIELD(size) +REFL_END -// REFL_TYPE(llama_mlock) -// REFL_FIELD(addr) -// REFL_FIELD(size) -// REFL_END +REFL_TYPE(llama_mlock) + REFL_FIELD(addr) + REFL_FIELD(size) +REFL_END -//REFL_TYPE(llama_state) -// REFL_FIELD(log_callback) -// REFL_FIELD(log_callback_user_data) -// REFL_END +REFL_TYPE(llama_state) + REFL_FIELD(log_callback) + REFL_FIELD(log_callback_user_data) + REFL_END -// REFL_TYPE(llama_hparams) -// REFL_FIELD(vocab_only) -// REFL_FIELD(n_vocab) -// REFL_END +REFL_TYPE(llama_hparams) + REFL_FIELD(vocab_only) + REFL_FIELD(n_vocab) + REFL_END REFL_TYPE(llama_cparams) @@ -346,24 +335,21 @@ REFL_TYPE(llama_cparams) REFL_FIELD(n_batch) REFL_END -//REFL_TYPE(llama_layer) -// REFL_FIELD(attn_norm) -// REFL_FIELD(attn_norm_b) -//REFL_END +REFL_TYPE(llama_layer) + REFL_FIELD(attn_norm) + REFL_FIELD(attn_norm_b) +REFL_END -// REFL_TYPE(llama_kv_cell) -// REFL_FIELD(pos) -// REFL_FIELD(delta) -// REFL_END +REFL_TYPE(llama_kv_cell) + REFL_FIELD(pos) + REFL_FIELD(delta) +REFL_END REFL_TYPE(llama_kv_cache) REFL_FIELD(has_shift) REFL_FIELD(head) REFL_END -// REFL_TYPE(llama_vocab) -// REFL_END - REFL_TYPE(e_model) REFL_END @@ -389,29 +375,22 @@ REFL_FIELD( output_norm) REFL_FIELD( output_norm_b) REFL_FIELD( output) -//REFL_FIELD( layers) +REFL_FIELD( layers) REFL_FIELD( n_gpu_layers) -//REFL_FIELD( gguf_kv) unordered map + REFL_FIELD( gguf_kv) //unordered map REFL_FIELD( ctx) REFL_FIELD( buf) -//REFL_FIELD( mapping) std::unique_ptr -//REFL_FIELD( mlock_buf) -//REFL_FIELD( mlock_mmap) -//REFL_FIELD( tensors_by_name) + REFL_FIELD( mapping) //std::unique_ptr +REFL_FIELD( mlock_buf) +REFL_FIELD( mlock_mmap) +REFL_FIELD( tensors_by_name) REFL_FIELD( t_load_us) REFL_FIELD( t_start_us) REFL_END - -REFL_TYPE(llama_hparams) - REFL_END - -//REFL_TYPE(std::vector >) -//REFL_END - REFL_TYPE(llama_vocab) REFL_END @@ -422,7 +401,7 @@ REFL_TYPE(llama_context) REFL_FIELD( cparams) //REFL_FIELD(model) REFL_FIELD(kv_self) -//REFL_FIELD(rng) random numbers + REFL_FIELD(rng) //random numbers REFL_FIELD(has_evaluated_once ) REFL_FIELD( t_start_us) REFL_FIELD( t_load_us) @@ -432,13 +411,13 @@ REFL_FIELD( t_p_eval_us ) REFL_FIELD( n_sample ) REFL_FIELD( n_p_eval ) REFL_FIELD( n_eval ) -//REFL_FIELD( logits) +REFL_FIELD( logits) REFL_FIELD( logits_all ) -//REFL_FIELD( embedding) -//REFL_FIELD( work_buffer) +REFL_FIELD( 
embedding) +REFL_FIELD( work_buffer) REFL_FIELD( buf_compute) REFL_FIELD( buf_alloc) -//REFL_FIELD( alloc ) +REFL_FIELD( alloc ) #ifdef GGML_USE_METAL REFL_FIELD( ctx_metal ) @@ -450,108 +429,102 @@ REFL_FIELD( ctx_mpi ) #endif REFL_END -// REFL_TYPE(llama_model_loader) -// REFL_FIELD(n_kv) -// REFL_FIELD(n_tensors) -// REFL_END +REFL_TYPE(llama_model_loader) + REFL_FIELD(n_kv) + REFL_FIELD(n_tensors) +REFL_END -// REFL_TYPE(llm_build_context) -// REFL_FIELD(model) -// REFL_FIELD(hparams) -// REFL_END +REFL_TYPE(llm_build_context) +// REFL_FIELD(model) cannot create pointer to reference member ‘llm_build_context::model’ +// REFL_FIELD(hparams) cannot create pointer to reference member ‘llm_build_context::hparams’ +REFL_END -// REFL_TYPE(llm_offload_trie) -// REFL_END +REFL_TYPE(llm_offload_trie) +REFL_END -// REFL_TYPE(llm_symbol) -// REFL_FIELD(prev) -// REFL_END +REFL_TYPE(llm_symbol) + REFL_FIELD(prev) +REFL_END -// REFL_TYPE(llm_bigram_spm) -// REFL_END +REFL_TYPE(llm_bigram_spm) +REFL_END -// REFL_TYPE(llm_tokenizer_spm) -// REFL_END +REFL_TYPE(llm_tokenizer_spm) +REFL_END -// REFL_TYPE(llm_bigram_bpe) -// REFL_END +REFL_TYPE(llm_bigram_bpe) +REFL_END -// REFL_TYPE(llm_tokenizer_bpe) -// REFL_END - - -// REFL_TYPE(fragment_buffer_variant) -// REFL_END - - -// REFL_TYPE(llama_partial_utf8) -// REFL_FIELD(value) -// REFL_FIELD(n_remain) -// REFL_END - - -REFL_TYPE(llama_grammar) -// REFL_FIELD(rules) -// REFL_FIELD(stacks) +REFL_TYPE(llm_tokenizer_bpe) REFL_END -//REFL_TYPE(llama_grammar_candidate) -// REFL_FIELD(index) -// REFL_FIELD(code_points) -//REFL_END +REFL_TYPE(fragment_buffer_variant) +REFL_END -// REFL_TYPE(llama_beam) -// REFL_FIELD(tokens) -// REFL_FIELD(p) -// REFL_END +REFL_TYPE(llama_partial_utf8) + REFL_FIELD(value) + REFL_FIELD(n_remain) +REFL_END -// REFL_TYPE(llama_logit_info) -// REFL_FIELD(logits) -// REFL_FIELD(n_vocab) -// REFL_END - -// REFL_TYPE(llama_beam_search_data) -// REFL_FIELD(ctx) -// REFL_FIELD(n_beams) -// REFL_END - - -// REFL_TYPE(quantize_state_internal) -// REFL_FIELD(model) -// REFL_FIELD(params) -// REFL_END - -// REFL_TYPE(llama_data_context) -// REFL_END +REFL_TYPE(llama_grammar) + REFL_FIELD(rules) + REFL_FIELD(stacks) +REFL_END -// REFL_TYPE(llama_data_buffer_context) -// REFL_FIELD(ptr) -// REFL_END -// REFL_TYPE(llama_data_file_context) -// REFL_FIELD(file) -// REFL_END +REFL_TYPE(llama_grammar_candidate) + REFL_FIELD(index) + REFL_FIELD(code_points) +REFL_END -// // A simple struct with some fields and a function -// // A custom attribute to mark some fields as hidden -struct hidden : refl::attr::usage::field {}; -// // Another struct with some fields and a function, using the custom attribute -// struct Person { -// std::string name; -// int age; -// [[hidden]] std::string password; -// void say_hello() const { -// std::cout << "Hello, I'm " << name << " and I'm " << age << " years old.\n"; -// } -// }; +REFL_TYPE(llama_beam) + REFL_FIELD(tokens) + REFL_FIELD(p) +REFL_END + + +REFL_TYPE(llama_logit_info) + REFL_FIELD(logits) + REFL_FIELD(n_vocab) +REFL_END + +REFL_TYPE(llama_beam_search_data) + REFL_FIELD(ctx) + REFL_FIELD(n_beams) +REFL_END + + +REFL_TYPE(quantize_state_internal) +// REFL_FIELD(model) + REFL_FIELD(params) +REFL_FIELD( n_attention_wv ) +REFL_FIELD( n_feed_forward_w2 ) + REFL_FIELD( i_attention_wv ) + REFL_FIELD( i_feed_forward_w2 ) +REFL_FIELD( n_k_quantized ) +REFL_FIELD( n_fallback ) + +REFL_END + +REFL_TYPE(llama_data_context) +REFL_END + +REFL_TYPE(llama_data_buffer_context) + REFL_FIELD(ptr) +REFL_END + 
+REFL_TYPE(llama_data_file_context) + REFL_FIELD(file) +REFL_END + // // A generic function to print out the fields of any object template <typename T> -void print_fields(const T& t) { +void print_fields(const T& ) { //return; // // Get the type descriptor of the object constexpr auto type = refl::reflect<T>();
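The patch above works because print.hpp can now see the full definitions of the formerly file-local structs (hence their move into ggml-internal.hpp and llama-internal.hpp) and registers them with refl-cpp's REFL_TYPE/REFL_FIELD/REFL_END macros, which is what lets the generic print_fields template enumerate their members. The following is a minimal, self-contained sketch of that same pattern, illustrative only and not part of the patch: the sample_params struct and its fields are hypothetical stand-ins, and it assumes refl-cpp (refl.hpp) is on the include path.

// Illustrative sketch, not part of the patch. Assumes refl-cpp is available;
// sample_params and its fields are hypothetical stand-ins for the llama/ggml
// internals registered in print.hpp.
#include <iostream>
#include <string>
#include <refl.hpp>

struct sample_params {
    int         n_ctx = 512;
    float       temp  = 0.8f;
    std::string model = "model.gguf";
};

// Same registration macros the patch adds to print.hpp.
REFL_TYPE(sample_params)
    REFL_FIELD(n_ctx)
    REFL_FIELD(temp)
    REFL_FIELD(model)
REFL_END

// Generic field printer in the spirit of print_fields() above.
template <typename T>
void print_fields_sketch(const T & obj) {
    constexpr auto type = refl::reflect<T>();       // type descriptor of T
    std::cout << type.name.c_str() << " {\n";
    refl::util::for_each(type.members, [&](auto member) {
        // member.name is the registered field name; member(obj) reads its value
        std::cout << "  " << member.name.c_str() << " = " << member(obj) << "\n";
    });
    std::cout << "}\n";
}

int main() {
    sample_params p;
    print_fields_sketch(p);   // prints the three registered fields
}

One consequence of this design, visible throughout the diff, is that reflection only covers what is explicitly registered: fields holding references (for example llm_build_context::model) stay commented out in print.hpp because refl-cpp cannot form a pointer-to-member to a reference member, as the comments in the patch note.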