Offloading tensors based on total VRAM budget and offloading policy (#6)
* deprecate ffn_b
* get tensor offloading levels
* wip: split tensor loading
* wip: framework of loading sparse model tensors
* save and flush gpu alloc buffer
* vram budget will fall back to remaining free memory
* minor: remove vram safety margin
* add options for vram budget; clean old env vars
* minor: bugfix

This commit is contained in:
parent b89a0b7296
commit 15b193729b

6 changed files with 418 additions and 120 deletions
@@ -565,6 +565,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--vram-budget") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
params.vram_budget_gb = std::stof(argv[i]);
#else
fprintf(stderr, "warning: PowerInfer was compiled without cuBLAS. It is not possible to set a VRAM budget.\n");
#endif
} else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
#ifdef GGML_USE_CUBLAS
params.mul_mat_q = false;
@@ -801,6 +811,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
printf(" --vram-budget N VRAM budget in GiB (default: -1, -1 = available VRAM)\n");
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
@@ -895,6 +906,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.vram_budget_gb = params.vram_budget_gb;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
@@ -1402,4 +1414,5 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "vram_budget: %f # default: -1.0 (all available VRAM)\n", params.vram_budget_gb);
}
@@ -64,6 +64,7 @@ struct gpt_params {
int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float vram_budget_gb = -1.0f; // VRAM budget in GB (-1 - use available VRAM)
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f; // YaRN low correction dim
@@ -9338,6 +9338,13 @@ int ggml_cuda_get_device_count() {
return device_count;
}

size_t ggml_cuda_get_free_memory(int device) {
size_t free, total;
CUDA_CHECK(cudaSetDevice(device));
CUDA_CHECK(cudaMemGetInfo(&free, &total));
return free;
}

void ggml_cuda_get_device_description(int device, char * description, size_t description_size) {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
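Illustrative sketch (not part of this diff): the loader added later in this commit consumes this query to fill in a missing budget; a negative budget falls back to whatever is currently free on the main device, otherwise the configured GiB value is converted to bytes. The helper name resolve_vram_capacity below is a hypothetical stand-in for that inline logic, assuming ggml-cuda.h is included:

    // Hypothetical helper mirroring the budget fallback introduced in this commit.
    static size_t resolve_vram_capacity(float vram_budget_gb, int main_gpu) {
        if (vram_budget_gb < 0.0f) {
            // negative budget: use all VRAM currently free on the main device
            return ggml_cuda_get_free_memory(main_gpu);
        }
        // explicit budget: convert GiB to bytes
        return (size_t)(vram_budget_gb * 1024.0 * 1024.0 * 1024.0);
    }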
@@ -9610,3 +9617,4 @@ ggml_backend_t ggml_backend_cuda_init() {

return cuda_backend;
}
@@ -51,6 +51,7 @@ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, s

GGML_API int ggml_cuda_get_device_count(void);
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API size_t ggml_cuda_get_free_memory(int device);

// backend API
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
514 llama.cpp
@@ -543,6 +543,38 @@ static llm_arch llm_arch_from_string(const std::string & name) {
return LLM_ARCH_UNKNOWN;
}

enum tensor_offloading_levels {
TENSOR_NO_OFFLOAD,
TENSOR_OFFLOAD_FFN,
TENSOR_OFFLOAD_ATTN,
TENSOR_OFFLOAD_MLP_PRED,
TENSOR_OFFLOAD_OUTPUT,
TENSOR_OFFLOAD_KV_CACHE,
};

tensor_offloading_levels get_offloading_level(llm_tensor tensor) {
switch (tensor) {
case LLM_TENSOR_TOKEN_EMBD: case LLM_TENSOR_TOKEN_EMBD_NORM: case LLM_TENSOR_POS_EMBD:
case LLM_TENSOR_ROPE_FREQS:
return TENSOR_NO_OFFLOAD;
case LLM_TENSOR_OUTPUT: case LLM_TENSOR_OUTPUT_NORM:
return TENSOR_OFFLOAD_OUTPUT;
case LLM_TENSOR_ATTN_Q: case LLM_TENSOR_ATTN_K: case LLM_TENSOR_ATTN_V:
case LLM_TENSOR_ATTN_QKV: case LLM_TENSOR_ATTN_OUT: case LLM_TENSOR_ATTN_NORM:
case LLM_TENSOR_ATTN_NORM_2: case LLM_TENSOR_ATTN_ROT_EMBD:
case LLM_TENSOR_ATTN_Q_NORM: case LLM_TENSOR_ATTN_K_NORM:
return TENSOR_OFFLOAD_ATTN;
case LLM_TENSOR_FFN_GATE: case LLM_TENSOR_FFN_DOWN: case LLM_TENSOR_FFN_UP:
case LLM_TENSOR_FFN_NORM: case LLM_TENSOR_FFN_DOWN_T:
return TENSOR_OFFLOAD_FFN;
case LLM_TENSOR_MLP_PRED_FC1: case LLM_TENSOR_MLP_PRED_FC2:
return TENSOR_OFFLOAD_MLP_PRED;
default:
throw std::runtime_error("unknown tensor category");
}
}

// helper to handle gguf constants
// usage:
//
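Illustrative note (not part of this diff): the enum order doubles as an offload priority, which is why the buffered allocator's flush() further down walks levels from TENSOR_OFFLOAD_ATTN upward when handing out VRAM. A minimal sketch of classifying a tensor and comparing priorities:

    // Attention tensors sit at a lower (earlier) level than output tensors,
    // so they are considered for VRAM before the output matrices.
    tensor_offloading_levels lvl = get_offloading_level(LLM_TENSOR_ATTN_Q); // TENSOR_OFFLOAD_ATTN
    bool considered_first = lvl < TENSOR_OFFLOAD_OUTPUT;                    // true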
@@ -557,20 +589,20 @@ struct LLM_TN {

llm_arch arch;

std::string operator()(llm_tensor tensor) const {
return LLM_TENSOR_NAMES[arch].at(tensor);
std::pair<std::string, tensor_offloading_levels> operator()(llm_tensor tensor) const {
return std::make_pair(LLM_TENSOR_NAMES[arch].at(tensor), get_offloading_level(tensor));
}

std::string operator()(llm_tensor tensor, const std::string & suffix) const {
return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix;
std::pair<std::string, tensor_offloading_levels> operator()(llm_tensor tensor, const std::string & suffix) const {
return std::make_pair(LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix, get_offloading_level(tensor));
}

std::string operator()(llm_tensor tensor, int bid) const {
return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid);
std::pair<std::string, tensor_offloading_levels> operator()(llm_tensor tensor, int bid) const {
return std::make_pair(::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid), get_offloading_level(tensor));
}

std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
std::pair<std::string, tensor_offloading_levels> operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
return std::make_pair(::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix, get_offloading_level(tensor));
}
};
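Illustrative usage (not part of this diff): LLM_TN now hands back both the formatted tensor name and its offloading level, so callers pick the piece they need via .first/.second; for example, something like:

    const auto tn = LLM_TN(model.arch);
    auto named = tn(LLM_TENSOR_FFN_UP, "weight", 0);   // e.g. {"blk.0.ffn_up.weight", TENSOR_OFFLOAD_FFN} for LLaMA
    const std::string & name   = named.first;          // name as stored in the GGUF file
    tensor_offloading_levels l = named.second;         // drives CPU/GPU placement in buffered_alloc()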
@@ -1915,7 +1947,11 @@ struct llama_model_loader {
return tensor;
}

struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, ggml_backend_type backend) {
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::pair<std::string, tensor_offloading_levels> & tn, const std::vector<int64_t> & ne, ggml_backend_type backend) {
return create_tensor(ctx, tn.first, ne, backend);
}

struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string &name, const std::vector<int64_t> & ne, ggml_backend_type backend) {
struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str());

if (cur == NULL) {
@@ -2882,20 +2918,244 @@ struct llama_augmentation_model_loader {
}
};

static bool should_offload_mlp_at_layer(int layer_idx) {
char * n_offload = getenv("N_OFFLOAD_MLP");
if (n_offload == nullptr) {
return false;
}
return layer_idx < atoi(n_offload);
}
struct buffered_tensor_allocator {
llama_model_loader &ml;
ggml_context *ctx;
std::map<tensor_offloading_levels, std::vector<ggml_tensor *>> alloc_queues;
const size_t vram_budget_bytes;
size_t vram_allocated_bytes = 0;

static bool should_offload_attention_at_layer(int layer_idx) {
char * n_offload = getenv("N_OFFLOAD_ATTN");
if (n_offload == nullptr) {
return false;
buffered_tensor_allocator(llama_model_loader &ml, ggml_context *ctx, size_t vram_budget_bytes) : ctx(ctx), ml(ml), vram_budget_bytes(vram_budget_bytes) {}

ggml_tensor * buffered_alloc(const std::string & name, const tensor_offloading_levels &level, const std::vector<int64_t> & ne) {
#if defined(GGML_USE_CUBLAS)
if (level == TENSOR_NO_OFFLOAD || level == TENSOR_OFFLOAD_FFN) {
return ml.create_tensor(ctx, name, ne, GGML_BACKEND_CPU);
}
// Alloc only metadata for GPU tensors
bool no_alloc = ctx->no_alloc;
ggml_set_no_alloc(ctx, true);
ggml_tensor * meta_tensor = ml.create_tensor(ctx, name, ne, GGML_BACKEND_CPU);
ggml_set_no_alloc(ctx, no_alloc);
alloc_queues[level].push_back(meta_tensor);
return meta_tensor;
#else
return ml.create_tensor(ctx, name, ne, GGML_BACKEND_CPU);
#endif
}
return layer_idx < atoi(n_offload);

// For GPU tensors, we need to allocate them in VRAM as much as possible,
// and update the tensor data in-place. If the VRAM budget is exceeded,
// we allocate the tensor in CPU memory.
void flush() {
#if defined(GGML_USE_CUBLAS)
// iterate over offloading priorities
for (int enum_i = TENSOR_OFFLOAD_ATTN; enum_i <= TENSOR_OFFLOAD_KV_CACHE; enum_i ++) {
tensor_offloading_levels level = static_cast<tensor_offloading_levels>(enum_i);
for (ggml_tensor * meta_tensor : alloc_queues[level]) {
size_t tensor_data_size = ggml_nbytes(meta_tensor);
if (vram_allocated_bytes + tensor_data_size > vram_budget_bytes) {
return;
}
// allocate in VRAM
ggml_set_backend(meta_tensor, GGML_BACKEND_GPU);
vram_allocated_bytes += tensor_data_size;
}
}
ml.done_getting_tensors();
#endif
}
};

static void llm_load_sparse_model_tensors(
llama_model_loader & ml,
llama_model & model,
int main_gpu,
long int vram_budget_bytes,
const float * tensor_split,
bool use_mlock,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
model.t_start_us = ggml_time_us();
auto & ctx = model.ctx;
auto & hparams = model.hparams;

size_t ctx_size;
size_t mmapped_size;
ml.calc_sizes(ctx_size, mmapped_size);
LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0);

// create the ggml context
{
model.buf.resize(ctx_size);
if (use_mlock) {
model.mlock_buf.init (model.buf.data);
model.mlock_buf.grow_to(model.buf.size);
}

struct ggml_init_params params = {
/*.mem_size =*/ model.buf.size,
/*.mem_buffer =*/ model.buf.data,
/*.no_alloc =*/ ml.use_mmap,
};

model.ctx = ggml_init(params);
if (!model.ctx) {
throw std::runtime_error(format("ggml_init() failed"));
}
}

(void) main_gpu;

enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
size_t vram_capacity = 0;

#ifdef GGML_USE_CUBLAS
if (ggml_cublas_loaded()) {
LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
ggml_cuda_set_main_device(main_gpu);

llama_backend_offload = GGML_BACKEND_GPU;
llama_backend_offload_split = GGML_BACKEND_GPU_SPLIT;
}
#elif defined(GGML_USE_CLBLAST)
LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__);
llama_backend_offload = GGML_BACKEND_GPU;
llama_backend_offload_split = GGML_BACKEND_GPU;
#endif

#if defined(GGML_USE_CUBLAS)
if (vram_budget_bytes < 0) {
// Let it be the rest of VRAM
vram_capacity = ggml_cuda_get_free_memory(main_gpu);
} else {
vram_capacity = vram_budget_bytes;
}
#endif

buffered_tensor_allocator alloc(ml, ctx, vram_capacity);
auto create_tensor = [&alloc] (
const std::pair<std::string, tensor_offloading_levels> & tn,
const std::vector<int64_t> & ne) -> ggml_tensor * {
return alloc.buffered_alloc(tn.first, tn.second, ne);
};

{
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd_gqa = hparams.n_embd_gqa();
const int64_t n_layer = hparams.n_layer;
const int64_t n_vocab = hparams.n_vocab;

const auto tn = LLM_TN(model.arch);
switch (model.arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

// output
{
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}

const uint32_t n_ff = hparams.n_ff;
model.layers.resize(n_layer);

for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});

layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff});
layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD});
layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff});
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
case LLM_ARCH_FALCON:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

// output
{
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}

const uint32_t n_ff = hparams.n_ff;

model.layers.resize(n_layer);

for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});

if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).first.c_str()) >= 0) {
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd});
}

layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff});
layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD});
layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff});
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
}

alloc.flush();

// print memory requirements
{
// this is the total memory required to run the inference
size_t mem_required =
ctx_size +
mmapped_size - alloc.vram_allocated_bytes; // weights in VRAM not in memory

LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);

#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, alloc.vram_allocated_bytes / 1024.0 / 1024.0);
#endif
}

// populate `tensors_by_name`
for (int i = 0; i < ml.n_tensors; ++i) {
struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i));
model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
}

ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL);

if (progress_callback) {
progress_callback(1.0f, progress_callback_user_data);
}

model.mapping = std::move(ml.mapping);

// loading time will be recalculate after the first eval, so
// we take page faults deferred by mmap() into consideration
model.t_load_us = ggml_time_us() - model.t_start_us;

model.n_gpu_layers = -1; // based on offloading results?
}

static void llm_load_tensors(
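Illustrative, self-contained sketch (not part of this diff) of the two-phase policy buffered_tensor_allocator implements: GPU-eligible tensors are first queued per offloading level with only metadata allocated, then flush() walks the levels in priority order and promotes tensors to VRAM until the byte budget is exhausted. The names, sizes, and budget below are made-up stand-ins:

    #include <cstdio>
    #include <map>
    #include <vector>

    enum level { ATTN = 2, MLP_PRED, OUTPUT, KV_CACHE }; // mirrors the tensor_offloading_levels order

    int main() {
        std::map<level, std::vector<size_t>> queues = {  // queued tensor sizes per level, in bytes
            {ATTN, {400, 400}}, {MLP_PRED, {300}}, {OUTPUT, {500}},
        };
        const size_t budget = 1000;                      // VRAM budget in bytes
        size_t used = 0;
        for (int l = ATTN; l <= KV_CACHE; ++l) {         // higher-priority levels are served first
            for (size_t sz : queues[(level) l]) {
                if (used + sz > budget) {                // budget exhausted: the rest stays on the CPU
                    std::printf("VRAM used: %zu bytes\n", used);
                    return 0;
                }
                used += sz;                              // "promote" this tensor to the GPU
            }
        }
        std::printf("VRAM used: %zu bytes\n", used);
        return 0;
    }

Run against these toy numbers it reports 800 bytes used: both attention tensors fit, the MLP predictor does not, and, as in flush(), allocation stops at the first tensor that would overflow the budget.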
@@ -2960,29 +3220,8 @@ static void llm_load_tensors(
llama_backend_offload_split = GGML_BACKEND_GPU;
#endif

// deprecated
auto ffn_b = [] (ggml_backend_type backend) -> ggml_backend_type {
const bool ffn_offloading = false;
if (ffn_offloading) {
return backend;
}
return GGML_BACKEND_CPU;
};

// prepare memory for the weights
size_t vram_weights = 0;
auto create_tensor = [&] (const std::string & name, const std::vector<int64_t> & ne, ggml_backend_type backend) -> ggml_tensor * {
ggml_tensor * created_tensor = ml.create_tensor(ctx, name, ne, backend);
if (created_tensor == nullptr) {
LLAMA_LOG_ERROR("%s: error: failed to create tensor '%s'\n", __func__, name.c_str());
return nullptr;
}
if (created_tensor->backend == GGML_BACKEND_GPU || created_tensor->backend == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(created_tensor);
}
return created_tensor;
};

{
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd_gqa = hparams.n_embd_gqa();
@@ -2994,7 +3233,7 @@ static void llm_load_tensors(
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);

// output
{
@@ -3016,8 +3255,15 @@ static void llm_load_tensors(
backend_output = GGML_BACKEND_CPU;
}

model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);

if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.output_norm);
}
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(model.output);
}
}

const uint32_t n_ff = hparams.n_ff;
@@ -3027,31 +3273,30 @@ static void llm_load_tensors(
model.layers.resize(n_layer);

for (uint32_t i = 0; i < n_layer; ++i) {
const ggml_backend_type attention_backend = should_offload_attention_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU;
const ggml_backend_type mlp_backend = should_offload_mlp_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU;
const ggml_backend_type ffn_backend = GGML_BACKEND_CPU;
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT

auto & layer = model.layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, attention_backend);
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, attention_backend);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, attention_backend);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, attention_backend);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, attention_backend);
layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split);
layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split);
layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_backend);
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);

layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, ffn_backend);
if (model.sparse_deriv == GGML_SPARSE_INFERENCE) {
layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff}, ffn_backend);
layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD}, mlp_backend);
layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff}, mlp_backend);
} else {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_backend);
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
}
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_backend);

}
} break;
case LLM_ARCH_BAICHUAN:
@@ -3106,11 +3351,11 @@ static void llm_load_tensors(
layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend));
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);

layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_b(backend_split));
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -3124,7 +3369,7 @@ static void llm_load_tensors(
{
// TODO: CPU-only for now

model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);

// output
{
@@ -3146,41 +3391,56 @@ static void llm_load_tensors(
backend_output = GGML_BACKEND_CPU;
}

model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);

if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.output_norm);
vram_weights += ggml_nbytes(model.output_norm_b);
}
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(model.output);
}
}

const uint32_t n_ff = hparams.n_ff;

const int i_gpu_start = n_layer - n_gpu_layers;

model.layers.resize(n_layer);

for (uint32_t i = 0; i < n_layer; ++i) {
const ggml_backend_type attention_backend = should_offload_attention_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU;
const ggml_backend_type mlp_backend = should_offload_mlp_at_layer(i) ? llama_backend_offload : GGML_BACKEND_CPU;
const ggml_backend_type ffn_backend = GGML_BACKEND_CPU;
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT

auto & layer = model.layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, attention_backend);
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, attention_backend);
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);

if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, attention_backend);
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, attention_backend);
if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).first.c_str()) >= 0) {
layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend);

if (backend == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(layer.attn_norm_2);
vram_weights += ggml_nbytes(layer.attn_norm_2_b);
}
}

layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, attention_backend);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, attention_backend);
// Either ffn_down or ffn_down_t is used, depending on the model type
if (model.sparse_deriv == GGML_SPARSE_INFERENCE) {
layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff}, ffn_backend);
layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD}, mlp_backend);
layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff}, mlp_backend);
} else {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) +
ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
}
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_backend);
}
} break;
case LLM_ARCH_STARCODER:
@@ -3242,14 +3502,14 @@ static void llm_load_tensors(
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);

layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend));
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend));
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);

layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, ffn_b(backend_split));
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, ffn_b(backend));
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);

layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, ffn_b(backend));
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -3318,12 +3578,12 @@ static void llm_load_tensors(
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, ffn_b(backend_split));
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, ffn_b(backend));
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, ffn_b(backend));
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend));
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend));
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend);
layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend);
layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend);
@@ -3392,14 +3652,14 @@ static void llm_load_tensors(
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);

layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend));
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend));
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);

layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, ffn_b(backend_split));
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, ffn_b(backend));
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);

layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, ffn_b(backend));
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -3463,10 +3723,10 @@ static void llm_load_tensors(
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend));
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);

layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_b(backend_split));
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -3538,12 +3798,12 @@ static void llm_load_tensors(
layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, ffn_b(backend));
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, ffn_b(backend));
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);

layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, ffn_b(backend_split));
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, ffn_b(backend_split));
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -3641,10 +3901,23 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con
return true;
}

llm_load_tensors(
ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
);
if (llama_use_sparse_inference(&model)) {
if (params.n_gpu_layers > 0) {
LLAMA_LOG_WARN("%s: sparse inference ignores n_gpu_layers, you can use --vram-budget option instead\n", __func__);
return false;
}
double vram_budget_bytes = params.vram_budget_gb * 1024.0 * 1024.0 * 1024.0;
llm_load_sparse_model_tensors(
ml, model, params.main_gpu, vram_budget_bytes, params.tensor_split, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
);
} else {
llm_load_tensors(
ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
);
}

} catch (const std::exception & err) {
LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
return false;
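To make the conversion above concrete: --vram-budget 8 becomes 8 * 1024 * 1024 * 1024 = 8,589,934,592 bytes, while the default of -1 stays negative and triggers the free-VRAM fallback inside llm_load_sparse_model_tensors.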
@@ -8296,7 +8569,7 @@ static ggml_type get_k_quant_type(
return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2;
};

if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
if (name == tn(LLM_TENSOR_OUTPUT, "weight").first) {
int nx = tensor->ne[0];
if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
new_type = GGML_TYPE_Q8_0;
@@ -8960,6 +9233,7 @@ struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.n_gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.vram_budget_gb =*/ -1.0,
/*.tensor_split =*/ nullptr,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
1 llama.h
@@ -161,6 +161,7 @@ extern "C" {
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
float vram_budget_gb; // VRAM budget in GB, -1 for all available VRAM (for a single GPU)
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)

// called with a progress value between 0 and 1, pass NULL to disable
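Illustrative end-to-end sketch (not part of this diff) of driving the new field through the public API; it assumes the usual llama_backend_init / llama_load_model_from_file / llama_free_model entry points of this llama.cpp generation and a placeholder model path:

    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_backend_init(/*numa=*/false);

        struct llama_model_params mparams = llama_model_default_params();
        mparams.vram_budget_gb = 8.0f; // cap GPU offloading at ~8 GiB; leave at -1.0f for all free VRAM

        struct llama_model * model = llama_load_model_from_file("model.powerinfer.gguf", mparams);
        if (model == NULL) {
            std::fprintf(stderr, "failed to load model\n");
            llama_backend_free();
            return 1;
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }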