diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp index 913fb21ea..3925b8970 100644 --- a/src/llama-adapter.cpp +++ b/src/llama-adapter.cpp @@ -1,5 +1,6 @@ #include "llama-adapter.h" +#include "llama-mmap.h" #include "llama-model.h" #include diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 38a55fb2c..4b195eaca 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1,5 +1,7 @@ #include "llama-context.h" +#include "llama-mmap.h" + #include #include #include @@ -504,7 +506,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { auto * buft = ggml_backend_cpu_buffer_type(); // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory - auto * output_dev = lctx.model.dev_output.dev; + auto * output_dev = lctx.model.dev_output(); auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; if (output_dev_host_buft) { buft = output_dev_host_buft; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 90b6c56ed..feffdf0de 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -79,7 +79,7 @@ bool llama_kv_cache_init( ggml_backend_buffer_type_t buft; if (offload) { - auto * dev = model.dev_layer.at(i).dev; + auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); } else { buft = ggml_backend_cpu_buffer_type(); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b2f1cf377..95f3be113 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1,14 +1,18 @@ #include "llama-model.h" #include "llama-impl.h" +#include "llama-mmap.h" #include "llama-model-loader.h" +#include "ggml-cpp.h" + #include "unicode.h" // TODO: remove #include #include #include #include +#include #include #include @@ -106,12 +110,273 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } -llama_model::llama_model(const struct llama_model_params & params) : params(params) { +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_MUL: + { 
+ ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + int n_embd_head = hparams.n_embd_head_v; + int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + // FIXME + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // FIXME + const int64_t d_state = w->ne[0]; + const int64_t d_inner = w->ne[1]; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 1; + ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd = hparams.n_embd; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; } +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + return nullptr; +} + +// CPU: ACCEL -> CPU extra -> GPU host -> CPU +static buft_list_t make_cpu_buft_list(const std::vector & devices) { 
+ buft_list_t buft_list; + + // add ACCEL buffer types + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + auto * buft = ggml_backend_dev_buffer_type(dev); + // skip + if (buft != ggml_backend_cpu_buffer_type()) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add extra buffer types + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // add a host buffer type + // storing the tensors in a host buffer is useful when the processing of large batches + // is offloaded to a GPU device, since it reduces the time spent on data transfers + // generally, this will be done using the first device in the list + // a better approach would be to handle this on a weight-by-weight basis using the offload_op + // function of the device to determine if it would benefit from being stored in a host buffer + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } + } + + // add the CPU buffer type + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + } + } + + return buft_list; +} + +// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU +static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) { + buft_list_t buft_list; + + // add the device split buffer type if requested and available + if (split_mode == LLAMA_SPLIT_MODE_ROW) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type"); + if (ggml_backend_split_buffer_type_fn) { + size_t dev_index = [&]() { + auto * reg = ggml_backend_dev_backend_reg(dev); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) { + if (ggml_backend_reg_dev_get(reg, i) == dev) { + return i; + } + } + throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev))); + }(); + auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split); + if (buft != nullptr) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add the device default buffer type + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + + return buft_list; +} + + +struct llama_model::impl { + uint64_t n_elements = 0; + + size_t n_bytes = 0; + + std::string desc_str; + + // model memory mapped files + llama_mmaps mappings; + + // objects representing data potentially being locked in memory + llama_mlocks mlock_bufs; + llama_mlocks mlock_mmaps; + + // contexts where the model tensors metadata is stored + std::vector ctxs; + + // the model memory buffers for the tensor data + std::vector bufs; + + buft_list_t 
cpu_buft_list; + std::map gpu_buft_list; + + struct layer_dev { + ggml_backend_dev_t dev; + buft_list_t * buft_list; + }; + + layer_dev dev_input = {}; + layer_dev dev_output = {}; + std::vector dev_layer; + +}; + +llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique()) { +} + +llama_model::~llama_model() = default; + void llama_model::load_stats(llama_model_loader & ml) { - n_elements = ml.n_elements; - n_bytes = ml.n_bytes; + pimpl->n_elements = ml.n_elements; + pimpl->n_bytes = ml.n_bytes; } void llama_model::load_arch(llama_model_loader & ml) { @@ -972,9 +1237,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: throw std::runtime_error("unsupported model architecture"); } - n_bytes = ml.n_bytes; + pimpl->n_bytes = ml.n_bytes; - desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); if (hparams.f_max_alibi_bias > 0.0f) { hparams.use_alibi = true; @@ -1694,128 +1959,6 @@ void llama_model::load_vocab(llama_model_loader & ml) { } } -// checks if the weight tensor can be used with the specified buffer type and device -static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { - GGML_ASSERT(w != nullptr); - - if (op == GGML_OP_NONE) { - return true; - } - - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - ggml_context_ptr ctx_ptr { ggml_init(params) }; - if (!ctx_ptr) { - throw std::runtime_error(format("failed to create ggml context")); - } - ggml_context * ctx = ctx_ptr.get(); - - ggml_tensor * op_tensor = nullptr; - - switch (op) { - case GGML_OP_GET_ROWS: - { - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_get_rows(ctx, w, b); - } break; - case GGML_OP_MUL_MAT: - { - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); - op_tensor = ggml_mul_mat(ctx, w, b); - } break; - case GGML_OP_MUL_MAT_ID: - { - int n_expert_used = hparams.n_expert_used; - ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); - ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); - op_tensor = ggml_mul_mat_id(ctx, w, b, ids); - } break; - case GGML_OP_ADD: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_add(ctx, a, w); - } break; - case GGML_OP_MUL: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_mul(ctx, a, w); - } break; - case GGML_OP_DIV: - { - ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); - op_tensor = ggml_div(ctx, a, w); - } break; - case GGML_OP_ROPE: - { - int n_embd_head = hparams.n_embd_head_v; - int n_head = hparams.n_head(); - ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_rope_ext( - ctx, a, b, w, - 0, 0, 0, 0, 0, - 0, 0, 0, 0 - ); - - } break; - case GGML_OP_SSM_CONV: - { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); - op_tensor = ggml_ssm_conv(ctx, conv_x, w); - } break; - case GGML_OP_SSM_SCAN: - { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; - const 
int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); - } break; - case GGML_OP_RWKV_WKV6: - { - // FIXME - const int64_t S = 123; - const int64_t H = 123; - const int64_t n_tokens = 123; - const int64_t n_seqs = 123; - ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); - ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); - ggml_tensor * tf = w; - ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); - ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); - op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); - } break; - case GGML_OP_IM2COL: - { - const int n_embd = hparams.n_embd; - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); - op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); - } break; - default: - GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); - } - - // create a temporary dummy buffer for the weight so that supports_op can check the buffer type - GGML_ASSERT(w->buffer == nullptr); - w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - ggml_backend_buffer_free(w->buffer); - w->buffer = nullptr; - - return op_supported; -} - bool llama_model::load_tensors(llama_model_loader & ml) { const auto & split_mode = params.split_mode; const auto & n_gpu_layers = params.n_gpu_layers; @@ -1827,20 +1970,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const bool use_mmap_buffer = true; // build a list of buffer types for the CPU and GPU devices - cpu_buft_list = make_cpu_buft_list(); + pimpl->cpu_buft_list = make_cpu_buft_list(devices); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback - buft_list.insert(buft_list.end(), cpu_buft_list.begin(), cpu_buft_list.end()); - gpu_buft_list.emplace(dev, std::move(buft_list)); + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); } // calculate the split points - bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_device(), [](float x) { return x == 0.0f; }); - std::vector splits(n_device()); + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); if (all_zero) { // default split, by free memory - for (size_t i = 0; i < n_device(); ++i) { + for (size_t i = 0; i < n_devices(); ++i) { ggml_backend_dev_t dev = devices[i]; size_t total; size_t free; @@ -1848,42 +1991,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) { splits[i] = free; } } else { - std::copy(tensor_split, tensor_split + n_device(), splits.begin()); + std::copy(tensor_split, 
tensor_split + n_devices(), splits.begin()); } // sum and normalize the splits to get the split points float split_sum = 0.0f; - for (size_t i = 0; i < n_device(); ++i) { + for (size_t i = 0; i < n_devices(); ++i) { split_sum += splits[i]; splits[i] = split_sum; } - for (size_t i = 0; i < n_device(); ++i) { + for (size_t i = 0; i < n_devices(); ++i) { splits[i] /= split_sum; } ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); - auto get_layer_buft_list = [&](int il) -> llama_model::layer_dev { + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { - return {cpu_dev, &cpu_buft_list}; + return {cpu_dev, &pimpl->cpu_buft_list}; } - const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_device(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); auto * dev = devices.at(layer_gpu); - return {dev, &gpu_buft_list.at(dev)}; + return {dev, &pimpl->gpu_buft_list.at(dev)}; }; // assign the input layer // there is very little benefit to offloading the input layer, so always keep it on the CPU - dev_input = { cpu_dev, &cpu_buft_list }; + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; // assign the repeating layers to the devices according to the splits - dev_layer.resize(n_layer); + pimpl->dev_layer.resize(n_layer); for (int il = 0; il < n_layer; ++il) { - dev_layer[il] = get_layer_buft_list(il); + pimpl->dev_layer[il] = get_layer_buft_list(il); } + // assign the output layer - dev_output = get_layer_buft_list(n_layer); + pimpl->dev_output = get_layer_buft_list(n_layer); // one ggml context per buffer type int max_n_tensors = ml.n_tensors; @@ -1900,12 +2044,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; + ggml_context * ctx = ggml_init(params); if (!ctx) { throw std::runtime_error(format("failed to create ggml context")); } + ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + pimpl->ctxs.emplace_back(ctx); + return ctx; } return it->second; @@ -1988,22 +2135,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // select the buffer type for this tensor - llama_model::buft_list_t * buft_list; + buft_list_t * buft_list; switch (info.layer) { case LLM_TENSOR_LAYER_INPUT: - buft_list = dev_input.buft_list; + buft_list = pimpl->dev_input.buft_list; break; case LLM_TENSOR_LAYER_OUTPUT: - buft_list = dev_output.buft_list; + buft_list = pimpl->dev_output.buft_list; break; case LLM_TENSOR_LAYER_REPEATING: - buft_list = dev_layer.at(tn.bid).buft_list; + buft_list = pimpl->dev_layer.at(tn.bid).buft_list; break; default: GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); } - ggml_backend_buffer_type_t buft = select_weight_buft(t_meta, op, *buft_list); + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list); if (!buft) { throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); } @@ -3865,8 +4012,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { ml.done_getting_tensors(); - ml.init_mappings(true, use_mlock ? 
&mlock_mmaps : nullptr); - mappings.reserve(ml.mappings.size()); + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); + pimpl->mappings.reserve(ml.mappings.size()); // create the backend buffers std::vector> ctx_bufs; @@ -3874,7 +4021,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // Ensure we have enough capacity for the maximum backend buffer we will potentially create const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); - bufs.reserve(n_max_backend_buffer); + pimpl->bufs.reserve(n_max_backend_buffer); for (auto & it : ctx_map) { ggml_backend_buffer_type_t buft = it.first; @@ -3915,7 +4062,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - bufs.emplace_back(buf); + pimpl->bufs.emplace_back(buf); buf_map.emplace(idx, buf); } } @@ -3924,10 +4071,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - bufs.emplace_back(buf); + pimpl->bufs.emplace_back(buf); if (use_mlock && ggml_backend_buffer_is_host(buf)) { - mlock_bufs.emplace_back(new llama_mlock); - auto & mlock_buf = mlock_bufs.back(); + pimpl->mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = pimpl->mlock_bufs.back(); mlock_buf->init (ggml_backend_buffer_get_base(buf)); mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); } @@ -3936,7 +4083,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } - if (bufs.empty()) { + if (pimpl->bufs.empty()) { throw std::runtime_error("failed to allocate buffer"); } @@ -3964,12 +4111,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // print memory requirements per buffer type - for (auto & buf : bufs) { + for (auto & buf : pimpl->bufs) { LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); } // populate tensors_by_name - for (auto & ctx : ctxs) { + for (auto & ctx : pimpl->ctxs) { for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { tensors_by_name.emplace_back(ggml_get_name(cur), cur); } @@ -3979,14 +4126,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (auto & it : ctx_bufs) { ggml_context * ctx = it.first; auto & bufs = it.second; - if (!ml.load_all_data(ctx, bufs, use_mlock ? &mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + if (!ml.load_all_data(ctx, bufs, use_mlock ? 
&pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { return false; } } if (use_mmap_buffer) { for (auto & mapping : ml.mappings) { - mappings.emplace_back(std::move(mapping)); + pimpl->mappings.emplace_back(std::move(mapping)); } } @@ -4002,21 +4149,25 @@ std::string llama_model::type_name() const { } std::string llama_model::desc() const { - return desc_str; + return pimpl->desc_str; } size_t llama_model::size() const { - return n_bytes; + return pimpl->n_bytes; } size_t llama_model::max_nodes() const { return std::max(8192, tensors_by_name.size()*5); } -size_t llama_model::n_device() const { +size_t llama_model::n_devices() const { return devices.size(); } +uint64_t llama_model::n_elements() const { + return pimpl->n_elements; +} + void llama_model::print_info() const { const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train); @@ -4093,14 +4244,14 @@ void llama_model::print_info() const { } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, n_elements*1e-12); - } else if (n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, n_elements*1e-9); - } else if (n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, n_elements*1e-6); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, n_elements*1e-3); + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); } // general kv @@ -4163,6 +4314,14 @@ void llama_model::print_info() const { } } +ggml_backend_dev_t llama_model::dev_layer(int il) const { + return pimpl->dev_layer.at(il).dev; +} + +ggml_backend_dev_t llama_model::dev_output() const { + return pimpl->dev_output.dev; +} + template static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { ggml_init_params params = { @@ -4191,7 +4350,7 @@ static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t d } template -static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) { +static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { for (const auto & cur : buft_list) { ggml_backend_dev_t cur_dev = cur.first; ggml_backend_buffer_type_t cur_buft = cur.second; @@ -4205,7 +4364,7 @@ static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & b ggml_backend_buffer_type_t llama_model::select_buft(int il) const { return ::select_buft( - *dev_layer.at(il).buft_list, + *pimpl->dev_layer.at(il).buft_list, [&](ggml_context * ctx) { ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); @@ -4242,104 +4401,6 @@ std::string llama_model::token_to_piece(llama_token token, bool special) const { return piece; } -// find the first buffer type in the list that can use the tensor -ggml_backend_buffer_type_t llama_model::select_weight_buft(ggml_tensor * tensor, ggml_op op, const llama_model::buft_list_t & buft_list) const { - 
GGML_ASSERT(!buft_list.empty()); - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { - return cur_buft; - } - } - return nullptr; -} - -// CPU: ACCEL -> CPU extra -> GPU host -> CPU -llama_model::buft_list_t llama_model::make_cpu_buft_list() const { - buft_list_t buft_list; - - // add ACCEL buffer types - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { - auto * buft = ggml_backend_dev_buffer_type(dev); - // skip - if (buft != ggml_backend_cpu_buffer_type()) { - buft_list.emplace_back(dev, buft); - } - } - } - - // add extra buffer types - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(cpu_dev, *extra_bufts); - ++extra_bufts; - } - } - - // add a host buffer type - // storing the tensors in a host buffer is useful when the processing of large batches - // is offloaded to a GPU device, since it reduces the time spent on data transfers - // generally, this will be done using the first device in the list - // a better approach would be to handle this on a weight-by-weight basis using the offload_op - // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; - } - } - - // add the CPU buffer type - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { - buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); - } - } - - return buft_list; -} - -llama_model::buft_list_t llama_model::make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) { - buft_list_t buft_list; - - // add the device split buffer type if requested and available - if (split_mode == LLAMA_SPLIT_MODE_ROW) { - ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) - ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type"); - if (ggml_backend_split_buffer_type_fn) { - size_t dev_index = [&]() { - auto * reg = ggml_backend_dev_backend_reg(dev); - for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) { - if (ggml_backend_reg_dev_get(reg, i) == dev) { - return i; - } - } - throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev))); - }(); - auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split); - if (buft != nullptr) { - buft_list.emplace_back(dev, buft); - } - } - } - - // add the device default buffer type - buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); - - return buft_list; -} - // // interface implementation // @@ -4530,7 +4591,7 @@ uint64_t 
llama_model_size(const struct llama_model * model) { } uint64_t llama_model_n_params(const struct llama_model * model) { - return model->n_elements; + return model->n_elements(); } bool llama_model_has_encoder(const struct llama_model * model) { diff --git a/src/llama-model.h b/src/llama-model.h index 9b4af5f91..93b99351e 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -4,10 +4,10 @@ #include "llama-arch.h" #include "llama-hparams.h" #include "llama-vocab.h" -#include "llama-mmap.h" - -#include "ggml-cpp.h" +#include +#include +#include #include struct llama_model_loader; @@ -315,53 +315,24 @@ struct llama_model { std::vector layers; + llama_model_params params; + // gguf metadata std::unordered_map gguf_kv; - llama_model_params params; - std::vector rpc_servers; // list of devices used in this model std::vector devices; - // lists of buffer types used for each layer - using buft_list_t = std::vector>; - buft_list_t cpu_buft_list; - std::map gpu_buft_list; - - struct layer_dev { - ggml_backend_dev_t dev; - buft_list_t * buft_list; - }; - - layer_dev dev_input = {}; - layer_dev dev_output = {}; - std::vector dev_layer; - - // contexts where the model tensors metadata is stored - std::vector ctxs; - - // the model memory buffers for the tensor data - std::vector bufs; - - // model memory mapped files - llama_mmaps mappings; - - // objects representing data potentially being locked in memory - llama_mlocks mlock_bufs; - llama_mlocks mlock_mmaps; - // for quantize-stats only std::vector> tensors_by_name; int64_t t_load_us = 0; int64_t t_start_us = 0; - // total number of parameters in the model - uint64_t n_elements = 0; - llama_model(const struct llama_model_params & params); + ~llama_model(); void load_stats (llama_model_loader & ml); void load_arch (llama_model_loader & ml); @@ -376,29 +347,25 @@ struct llama_model { size_t size() const; size_t max_nodes() const; - size_t n_device() const; + size_t n_devices() const; + + // total number of parameters in the model + uint64_t n_elements() const; void print_info() const; + ggml_backend_dev_t dev_layer(int il) const; + ggml_backend_dev_t dev_output() const; + ggml_backend_buffer_type_t select_buft(int il) const; const struct ggml_tensor * get_tensor(const char * name) const; private: - size_t n_bytes = 0; - - std::string desc_str; + struct impl; + std::unique_ptr pimpl; std::string token_to_piece(llama_token token, bool special) const; - - // find the first buffer type in the list that can use the tensor - ggml_backend_buffer_type_t select_weight_buft(ggml_tensor * tensor, ggml_op op, const llama_model::buft_list_t & buft_list) const; - - // CPU: ACCEL -> CPU extra -> GPU host -> CPU - buft_list_t make_cpu_buft_list() const; - - // GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU - buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split); }; const char * llm_type_name(llm_type type); diff --git a/src/llama.cpp b/src/llama.cpp index fd84c16f5..f579ad164 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8045,9 +8045,9 @@ static struct ggml_cgraph * llama_build_graph( const bool full_offload = lctx.model.params.n_gpu_layers > (int) lctx.model.hparams.n_layer; if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { - const auto & dev_layer = lctx.model.dev_layer.at(il); + const auto & dev_layer = lctx.model.dev_layer(il); for (auto & backend : lctx.backends) { - if (ggml_backend_get_device(backend.get()) == dev_layer.dev) { + if 
(ggml_backend_get_device(backend.get()) == dev_layer) { if (ggml_backend_supports_op(backend.get(), cur)) { ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get()); } @@ -9365,7 +9365,7 @@ struct llama_model * llama_model_load_from_file( LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024); } - int status = llama_model_load(path_model, *model, params); + const int status = llama_model_load(path_model, *model, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -9647,7 +9647,7 @@ struct llama_context * llama_new_context_with_model( // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = - model->n_device() > 1 && + model->n_devices() > 1 && model->params.n_gpu_layers > (int)model->hparams.n_layer && model->params.split_mode == LLAMA_SPLIT_MODE_LAYER && params.offload_kqv;
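
For reference, the llama-model changes follow the standard pimpl idiom: the private state of llama_model (mappings, mlocks, contexts, buffers, buffer-type lists, per-layer device assignments) moves into a nested impl struct defined only in llama-model.cpp, and the header keeps a std::unique_ptr<impl> plus thin accessors such as dev_layer(il), dev_output(), n_elements() and n_devices(). A minimal sketch of the pattern, using hypothetical names rather than the actual llama.cpp declarations:

// widget.h -- public header; the hidden state needs no includes here
#include <memory>

struct widget {
    widget();
    ~widget();                    // must be defined where impl is a complete type

    int n_items() const;          // thin accessor instead of a public data member

private:
    struct impl;                  // forward declaration only
    std::unique_ptr<impl> pimpl;  // all private state lives behind this pointer
};

// widget.cpp -- the only translation unit that sees the full impl
#include "widget.h"
#include <vector>

struct widget::impl {
    std::vector<int> items;       // heavy members stay out of the header
};

widget::widget() : pimpl(std::make_unique<impl>()) {}
widget::~widget() = default;      // defaulted here, where impl is complete

int widget::n_items() const {
    return (int) pimpl->items.size();
}

This is why the diff declares ~llama_model() in llama-model.h and defaults it in llama-model.cpp: a std::unique_ptr<impl> cannot be destroyed where impl is still an incomplete type. It is also why callers in llama-context.cpp, llama-kv-cache.cpp and llama.cpp switch from reaching into fields (model.dev_layer.at(i).dev) to calling the new accessors (model.dev_layer(i)).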
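
The buffer-type selection now lives entirely in llama-model.cpp: weight_buft_supported() builds a representative op for the weight in a no_alloc ggml context, temporarily attaches a zero-size buffer of the candidate type to the weight, and asks the device whether it could execute the op; select_weight_buft() then walks the (device, buffer type) priority list and returns the first candidate that passes. A condensed sketch of the same check, covering only the GGML_OP_MUL_MAT case and using a hypothetical function name:

#include "ggml-backend.h"
#include "ggml-cpp.h"

static bool mul_mat_weight_supported(ggml_tensor * w, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // metadata only, no tensor data is allocated
    };
    ggml_context_ptr ctx { ggml_init(params) };

    // a representative activation batch: 512 rows with the weight's input width
    ggml_tensor * b  = ggml_new_tensor_4d(ctx.get(), GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
    ggml_tensor * op = ggml_mul_mat(ctx.get(), w, b);

    // supports_op() looks at the source buffers, so give the weight a dummy
    // zero-size buffer of the candidate type for the duration of the query
    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
    const bool ok = ggml_backend_dev_supports_op(dev, op);
    ggml_backend_buffer_free(w->buffer);
    w->buffer = nullptr;

    return ok;
}

The priority order encoded by the lists is: ACCEL devices, then extra CPU buffer types, then a GPU host buffer, then the plain CPU buffer type for the CPU list; and the row-split buffer type (when LLAMA_SPLIT_MODE_ROW is requested) followed by the device's default buffer type for each GPU list, with the CPU list appended as a fallback.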
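
The layer-to-device assignment in load_tensors() is a prefix-sum lookup: the per-device weights (free memory by default, or the user-supplied tensor_split) are accumulated and normalized to cumulative shares in [0, 1], each offloaded layer is mapped to its fractional position among the offloaded layers, and std::upper_bound picks the first device whose cumulative share exceeds that fraction. A small self-contained illustration of the same arithmetic, with made-up numbers:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical per-device weights, e.g. free memory on three GPUs
    std::vector<float> splits = { 24.0f, 8.0f, 8.0f };

    // prefix sum, then normalize -> cumulative shares 0.6, 0.8, 1.0
    float sum = 0.0f;
    for (auto & s : splits) { sum += s; s = sum; }
    for (auto & s : splits) { s /= sum; }

    const int n_offload = 40;   // number of repeating layers being offloaded
    for (int il = 0; il < n_offload; ++il) {
        const float frac = float(il) / n_offload;   // position of this layer in [0, 1)
        const int   dev  = (int) (std::upper_bound(splits.begin(), splits.end(), frac) - splits.begin());
        std::printf("layer %2d -> device %d\n", il, dev);
    }
    return 0;
}

With these numbers the first 24 layers land on device 0, the next 8 on device 1 and the last 8 on device 2, matching the 60/20/20 split. Layers below i_gpu_start, and the input layer, stay on the CPU, while the output layer is assigned with the same lookup using index n_layer.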