From 15b193729bea372e59086fa7fd52c21e5520c9e1 Mon Sep 17 00:00:00 2001 From: Holden X Date: Fri, 15 Dec 2023 23:46:51 +0800 Subject: [PATCH] Offloading tensors based on total VRAM budget and offloading policy (#6) * deprecate ffn_b * get tensor offloading levels * wip: split tensor loading * wip: framework of loading sparse model tensors * save and flush gpu alloc buffer * vram budget will fall back to remaining free memory * minor: remove vram safety margin * add options for vram budget; clean old env vars * minor: bugfix --- common/common.cpp | 13 ++ common/common.h | 1 + ggml-cuda.cu | 8 + ggml-cuda.h | 1 + llama.cpp | 514 +++++++++++++++++++++++++++++++++++----------- llama.h | 1 + 6 files changed, 418 insertions(+), 120 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9f8cb33f0..429f313a3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -565,6 +565,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { #else fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); #endif // GGML_USE_CUBLAS + } else if (arg == "--vram-budget") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef GGML_USE_CUBLAS + params.vram_budget_gb = std::stof(argv[i]); +#else + fprintf(stderr, "warning: PowerInfer was compiled without cuBLAS. It is not possible to set a VRAM budget.\n"); +#endif } else if (arg == "--no-mul-mat-q" || arg == "-nommq") { #ifdef GGML_USE_CUBLAS params.mul_mat_q = false; @@ -801,6 +811,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); + printf(" --vram-budget N VRAM budget in GiB (default: -1, -1 = available VRAM)\n"); printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); @@ -895,6 +906,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.n_gpu_layers = params.n_gpu_layers; } mparams.main_gpu = params.main_gpu; + mparams.vram_budget_gb = params.vram_budget_gb; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; @@ -1402,4 +1414,5 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); + fprintf(stream, "vram_budget: %f # default: -1.0 (all available VRAM)\n", params.vram_budget_gb); } diff --git a/common/common.h b/common/common.h index 7714102ec..204970181 100644 --- a/common/common.h +++ b/common/common.h @@ -64,6 +64,7 @@ struct gpt_params { int32_t n_beams = 0; // if non-zero then use beam search of given width. float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + float vram_budget_gb = -1.0f; // VRAM budget in GB (-1 - use available VRAM) float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor float yarn_beta_fast = 32.0f; // YaRN low correction dim diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 78a52a84d..8ef1d013d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -9338,6 +9338,13 @@ int ggml_cuda_get_device_count() { return device_count; } +size_t ggml_cuda_get_free_memory(int device) { + size_t free, total; + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaMemGetInfo(&free, &total)); + return free; +} + void ggml_cuda_get_device_description(int device, char * description, size_t description_size) { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); @@ -9610,3 +9617,4 @@ ggml_backend_t ggml_backend_cuda_init() { return cuda_backend; } + diff --git a/ggml-cuda.h b/ggml-cuda.h index d253bab0e..6ea73cd0f 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -51,6 +51,7 @@ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, s GGML_API int ggml_cuda_get_device_count(void); GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size); +GGML_API size_t ggml_cuda_get_free_memory(int device); // backend API GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use diff --git a/llama.cpp b/llama.cpp index 01123d8d9..e2e466d0e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -543,6 +543,38 @@ static llm_arch llm_arch_from_string(const std::string & name) { return LLM_ARCH_UNKNOWN; } +enum tensor_offloading_levels { + TENSOR_NO_OFFLOAD, + TENSOR_OFFLOAD_FFN, + TENSOR_OFFLOAD_ATTN, + TENSOR_OFFLOAD_MLP_PRED, + TENSOR_OFFLOAD_OUTPUT, + TENSOR_OFFLOAD_KV_CACHE, +}; + +tensor_offloading_levels get_offloading_level(llm_tensor tensor) { + switch (tensor) { + case LLM_TENSOR_TOKEN_EMBD: case LLM_TENSOR_TOKEN_EMBD_NORM: case LLM_TENSOR_POS_EMBD: + case LLM_TENSOR_ROPE_FREQS: + return TENSOR_NO_OFFLOAD; + case LLM_TENSOR_OUTPUT: case LLM_TENSOR_OUTPUT_NORM: + return TENSOR_OFFLOAD_OUTPUT; + case LLM_TENSOR_ATTN_Q: case LLM_TENSOR_ATTN_K: case LLM_TENSOR_ATTN_V: + case LLM_TENSOR_ATTN_QKV: case LLM_TENSOR_ATTN_OUT: case LLM_TENSOR_ATTN_NORM: + case LLM_TENSOR_ATTN_NORM_2: case LLM_TENSOR_ATTN_ROT_EMBD: + case LLM_TENSOR_ATTN_Q_NORM: case LLM_TENSOR_ATTN_K_NORM: + return TENSOR_OFFLOAD_ATTN; + case LLM_TENSOR_FFN_GATE: case LLM_TENSOR_FFN_DOWN: case LLM_TENSOR_FFN_UP: + case LLM_TENSOR_FFN_NORM: case LLM_TENSOR_FFN_DOWN_T: + return TENSOR_OFFLOAD_FFN; + case LLM_TENSOR_MLP_PRED_FC1: case LLM_TENSOR_MLP_PRED_FC2: + return TENSOR_OFFLOAD_MLP_PRED; + default: + throw std::runtime_error("unknown tensor category"); + } +} + + // helper to handle gguf constants // usage: // @@ -557,20 +589,20 @@ struct LLM_TN { llm_arch arch; - std::string operator()(llm_tensor tensor) const { - return LLM_TENSOR_NAMES[arch].at(tensor); + std::pair operator()(llm_tensor tensor) const { + return std::make_pair(LLM_TENSOR_NAMES[arch].at(tensor), get_offloading_level(tensor)); } - std::string operator()(llm_tensor tensor, const std::string & suffix) const { - return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix; + std::pair operator()(llm_tensor tensor, const std::string & suffix) const { + return std::make_pair(LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix, get_offloading_level(tensor)); } - std::string operator()(llm_tensor tensor, int bid) const { - return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid); + std::pair operator()(llm_tensor tensor, int bid) const { + return std::make_pair(::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid), get_offloading_level(tensor)); } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { - return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; + std::pair operator()(llm_tensor tensor, const std::string & suffix, int bid) const { + return std::make_pair(::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix, get_offloading_level(tensor)); } }; @@ -1915,7 +1947,11 @@ struct llama_model_loader { return tensor; } - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend_type backend) { + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::pair & tn, const std::vector & ne, ggml_backend_type backend) { + return create_tensor(ctx, tn.first, ne, backend); + } + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string &name, const std::vector & ne, ggml_backend_type backend) { struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); if (cur == NULL) { @@ -2882,20 +2918,244 @@ struct llama_augmentation_model_loader { } }; -static bool should_offload_mlp_at_layer(int layer_idx) { - char * n_offload = getenv("N_OFFLOAD_MLP"); - if (n_offload == nullptr) { - return false; - } - return layer_idx < atoi(n_offload); -} +struct buffered_tensor_allocator { + llama_model_loader &ml; + ggml_context *ctx; + std::map> alloc_queues; + const size_t vram_budget_bytes; + size_t vram_allocated_bytes = 0; -static bool should_offload_attention_at_layer(int layer_idx) { - char * n_offload = getenv("N_OFFLOAD_ATTN"); - if (n_offload == nullptr) { - return false; + buffered_tensor_allocator(llama_model_loader &ml, ggml_context *ctx, size_t vram_budget_bytes) : ctx(ctx), ml(ml), vram_budget_bytes(vram_budget_bytes) {} + + ggml_tensor * buffered_alloc(const std::string & name, const tensor_offloading_levels &level, const std::vector & ne) { +#if defined(GGML_USE_CUBLAS) + if (level == TENSOR_NO_OFFLOAD || level == TENSOR_OFFLOAD_FFN) { + return ml.create_tensor(ctx, name, ne, GGML_BACKEND_CPU); + } + // Alloc only metadata for GPU tensors + bool no_alloc = ctx->no_alloc; + ggml_set_no_alloc(ctx, true); + ggml_tensor * meta_tensor = ml.create_tensor(ctx, name, ne, GGML_BACKEND_CPU); + ggml_set_no_alloc(ctx, no_alloc); + alloc_queues[level].push_back(meta_tensor); + return meta_tensor; +#else + return ml.create_tensor(ctx, name, ne, GGML_BACKEND_CPU); +#endif } - return layer_idx < atoi(n_offload); + + // For GPU tensors, we need to allocate them in VRAM as much as possible, + // and update the tensor data in-place. If the VRAM budget is exceeded, + // we allocate the tensor in CPU memory. + void flush() { +#if defined(GGML_USE_CUBLAS) + // iterate over offloading priorities + for (int enum_i = TENSOR_OFFLOAD_ATTN; enum_i <= TENSOR_OFFLOAD_KV_CACHE; enum_i ++) { + tensor_offloading_levels level = static_cast(enum_i); + for (ggml_tensor * meta_tensor : alloc_queues[level]) { + size_t tensor_data_size = ggml_nbytes(meta_tensor); + if (vram_allocated_bytes + tensor_data_size > vram_budget_bytes) { + return; + } + // allocate in VRAM + ggml_set_backend(meta_tensor, GGML_BACKEND_GPU); + vram_allocated_bytes += tensor_data_size; + } + } + ml.done_getting_tensors(); +#endif + } +}; + +static void llm_load_sparse_model_tensors( + llama_model_loader & ml, + llama_model & model, + int main_gpu, + long int vram_budget_bytes, + const float * tensor_split, + bool use_mlock, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { + model.t_start_us = ggml_time_us(); + auto & ctx = model.ctx; + auto & hparams = model.hparams; + + size_t ctx_size; + size_t mmapped_size; + ml.calc_sizes(ctx_size, mmapped_size); + LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); + + // create the ggml context + { + model.buf.resize(ctx_size); + if (use_mlock) { + model.mlock_buf.init (model.buf.data); + model.mlock_buf.grow_to(model.buf.size); + } + + struct ggml_init_params params = { + /*.mem_size =*/ model.buf.size, + /*.mem_buffer =*/ model.buf.data, + /*.no_alloc =*/ ml.use_mmap, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + throw std::runtime_error(format("ggml_init() failed")); + } + } + + (void) main_gpu; + + enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU; + enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU; + size_t vram_capacity = 0; + +#ifdef GGML_USE_CUBLAS + if (ggml_cublas_loaded()) { + LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); + ggml_cuda_set_main_device(main_gpu); + + llama_backend_offload = GGML_BACKEND_GPU; + llama_backend_offload_split = GGML_BACKEND_GPU_SPLIT; + } +#elif defined(GGML_USE_CLBLAST) + LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__); + llama_backend_offload = GGML_BACKEND_GPU; + llama_backend_offload_split = GGML_BACKEND_GPU; +#endif + +#if defined(GGML_USE_CUBLAS) + if (vram_budget_bytes < 0) { + // Let it be the rest of VRAM + vram_capacity = ggml_cuda_get_free_memory(main_gpu); + } else { + vram_capacity = vram_budget_bytes; + } +#endif + + buffered_tensor_allocator alloc(ml, ctx, vram_capacity); + auto create_tensor = [&alloc] ( + const std::pair & tn, + const std::vector & ne) -> ggml_tensor * { + return alloc.buffered_alloc(tn.first, tn.second, ne); + }; + + { + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const int64_t n_layer = hparams.n_layer; + const int64_t n_vocab = hparams.n_vocab; + + const auto tn = LLM_TN(model.arch); + switch (model.arch) { + case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + const uint32_t n_ff = hparams.n_ff; + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff}); + layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD}); + layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff}); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + case LLM_ARCH_FALCON: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + const uint32_t n_ff = hparams.n_ff; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).first.c_str()) >= 0) { + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}); + } + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff}); + layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD}); + layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff}); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + default: + throw std::runtime_error("unknown architecture"); + } + } + + alloc.flush(); + + // print memory requirements + { + // this is the total memory required to run the inference + size_t mem_required = + ctx_size + + mmapped_size - alloc.vram_allocated_bytes; // weights in VRAM not in memory + + LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); + +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, alloc.vram_allocated_bytes / 1024.0 / 1024.0); +#endif + } + + // populate `tensors_by_name` + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); + model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + + ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); + + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); + } + + model.mapping = std::move(ml.mapping); + + // loading time will be recalculate after the first eval, so + // we take page faults deferred by mmap() into consideration + model.t_load_us = ggml_time_us() - model.t_start_us; + + model.n_gpu_layers = -1; // based on offloading results? } static void llm_load_tensors( @@ -2960,29 +3220,8 @@ static void llm_load_tensors( llama_backend_offload_split = GGML_BACKEND_GPU; #endif - // deprecated - auto ffn_b = [] (ggml_backend_type backend) -> ggml_backend_type { - const bool ffn_offloading = false; - if (ffn_offloading) { - return backend; - } - return GGML_BACKEND_CPU; - }; - // prepare memory for the weights size_t vram_weights = 0; - auto create_tensor = [&] (const std::string & name, const std::vector & ne, ggml_backend_type backend) -> ggml_tensor * { - ggml_tensor * created_tensor = ml.create_tensor(ctx, name, ne, backend); - if (created_tensor == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to create tensor '%s'\n", __func__, name.c_str()); - return nullptr; - } - if (created_tensor->backend == GGML_BACKEND_GPU || created_tensor->backend == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(created_tensor); - } - return created_tensor; - }; - { const int64_t n_embd = hparams.n_embd; const int64_t n_embd_gqa = hparams.n_embd_gqa(); @@ -2994,7 +3233,7 @@ static void llm_load_tensors( case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: { - model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output { @@ -3016,8 +3255,15 @@ static void llm_load_tensors( backend_output = GGML_BACKEND_CPU; } - model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } } const uint32_t n_ff = hparams.n_ff; @@ -3027,31 +3273,30 @@ static void llm_load_tensors( model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type attention_backend = should_offload_attention_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU; - const ggml_backend_type mlp_backend = should_offload_mlp_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU; - const ggml_backend_type ffn_backend = GGML_BACKEND_CPU; + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT auto & layer = model.layers[i]; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, attention_backend); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, attention_backend); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, attention_backend); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, attention_backend); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, attention_backend); + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_backend); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, ffn_backend); - if (model.sparse_deriv == GGML_SPARSE_INFERENCE) { - layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff}, ffn_backend); - layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD}, mlp_backend); - layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff}, mlp_backend); - } else { - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_backend); + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); } - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_backend); - } } break; case LLM_ARCH_BAICHUAN: @@ -3106,11 +3351,11 @@ static void llm_load_tensors( layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend)); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_b(backend_split)); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -3124,7 +3369,7 @@ static void llm_load_tensors( { // TODO: CPU-only for now - model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output { @@ -3146,41 +3391,56 @@ static void llm_load_tensors( backend_output = GGML_BACKEND_CPU; } - model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); - model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } } const uint32_t n_ff = hparams.n_ff; + const int i_gpu_start = n_layer - n_gpu_layers; + model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type attention_backend = should_offload_attention_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU; - const ggml_backend_type mlp_backend = should_offload_mlp_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU; - const ggml_backend_type ffn_backend = GGML_BACKEND_CPU; + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT auto & layer = model.layers[i]; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, attention_backend); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, attention_backend); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); - if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, attention_backend); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, attention_backend); + if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).first.c_str()) >= 0) { + layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend); + layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(layer.attn_norm_2); + vram_weights += ggml_nbytes(layer.attn_norm_2_b); + } } - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, attention_backend); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, attention_backend); - // Either ffn_down or ffn_down_t is used, depending on the model type - if (model.sparse_deriv == GGML_SPARSE_INFERENCE) { - layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff}, ffn_backend); - layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD}, mlp_backend); - layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff}, mlp_backend); - } else { - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); } - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_backend); } } break; case LLM_ARCH_STARCODER: @@ -3242,14 +3502,14 @@ static void llm_load_tensors( layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend)); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend)); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, ffn_b(backend_split)); - layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, ffn_b(backend)); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); - layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, ffn_b(backend)); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -3318,12 +3578,12 @@ static void llm_load_tensors( layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, ffn_b(backend_split)); - layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, ffn_b(backend)); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); - layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, ffn_b(backend)); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend)); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend)); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend); layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend); layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend); @@ -3392,14 +3652,14 @@ static void llm_load_tensors( layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend)); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend)); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, ffn_b(backend_split)); - layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, ffn_b(backend)); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); - layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, ffn_b(backend)); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -3463,10 +3723,10 @@ static void llm_load_tensors( layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend)); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_b(backend_split)); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -3538,12 +3798,12 @@ static void llm_load_tensors( layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend)); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend)); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_b(backend_split)); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split)); + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -3641,10 +3901,23 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con return true; } - llm_load_tensors( - ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock, - params.progress_callback, params.progress_callback_user_data - ); + if (llama_use_sparse_inference(&model)) { + if (params.n_gpu_layers > 0) { + LLAMA_LOG_WARN("%s: sparse inference ignores n_gpu_layers, you can use --vram-budget option instead\n", __func__); + return false; + } + double vram_budget_bytes = params.vram_budget_gb * 1024.0 * 1024.0 * 1024.0; + llm_load_sparse_model_tensors( + ml, model, params.main_gpu, vram_budget_bytes, params.tensor_split, params.use_mlock, + params.progress_callback, params.progress_callback_user_data + ); + } else { + llm_load_tensors( + ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock, + params.progress_callback, params.progress_callback_user_data + ); + } + } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); return false; @@ -8296,7 +8569,7 @@ static ggml_type get_k_quant_type( return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; }; - if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { + if (name == tn(LLM_TENSOR_OUTPUT, "weight").first) { int nx = tensor->ne[0]; if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) { new_type = GGML_TYPE_Q8_0; @@ -8960,6 +9233,7 @@ struct llama_model_params llama_model_default_params() { struct llama_model_params result = { /*.n_gpu_layers =*/ 0, /*.main_gpu =*/ 0, + /*.vram_budget_gb =*/ -1.0, /*.tensor_split =*/ nullptr, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, diff --git a/llama.h b/llama.h index bd4a416f2..1d73b5223 100644 --- a/llama.h +++ b/llama.h @@ -161,6 +161,7 @@ extern "C" { struct llama_model_params { int32_t n_gpu_layers; // number of layers to store in VRAM int32_t main_gpu; // the GPU that is used for scratch and small tensors + float vram_budget_gb; // VRAM budget in GB, -1 for all available VRAM (for a single GPU) const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) // called with a progress value between 0 and 1, pass NULL to disable