Online GPU slicing (#11)

* move gpu slicing python code into a module

* remove dead code in exporting gpu split

* streamline solver and export with one entrypoint

* new powerinfer.py module

* wip: invoke Python to generate gpu split on the fly

* wip: load gpu split on demand

* wip: new gpu split file format

* wip: generate and load new gpu idx format

* wip: generate and load gpu index on the fly

* minor: calculate total VRAM offloading via FFN splitting

* add option to disable gpu index

* bugfix

* wip: bug fix for segmentation fault

* bugfix

* bugfix and testing

* temporary fix for neuron factor in solving

* fix: generated gpu idx path

* Update README about gpu index
Holden X 2023-12-20 10:09:43 +08:00 committed by GitHub
parent ded0613bd4
commit bb486b88e1
16 changed files with 419 additions and 481 deletions

View file

@ -71,6 +71,7 @@ And new features coming soon:
```bash
git clone https://github.com/SJTU-IPADS/PowerInfer
cd PowerInfer
pip install -r requirements.txt # install Python helpers' dependencies
```
### Build
In order to build PowerInfer you have two different options. These commands are supposed to be run from the root directory of the project.
@ -89,7 +90,8 @@ cmake --build build --config Release
## Model Weights
PowerInfer models are stored in a special format called *PowerInfer GGUF* based on GGUF format, consisting of both LLM weights and predictor weights. You can download PowerInfer GGUF weights from Hugging Face or convert them from the original model weights and predictor weights.
PowerInfer models are stored in a special format called *PowerInfer GGUF* based on GGUF format, consisting of both LLM weights and predictor weights.
You can obtain PowerInfer GGUF weights (`*.powerinfer.gguf`), along with profiled model activation statistics under `activation/` for 'hot'-neuron offloading, from each Hugging Face model repo listed in the "PowerInfer GGUF Format" column. You can also convert them from the original model weights and predictor weights.
| Base Model | PowerInfer GGUF Format | Original Model | Predictor |
|------------|------------------|----------------|---------------------|
@ -102,14 +104,16 @@ PowerInfer models are stored in a special format called *PowerInfer GGUF* based
For CPU-only and CPU-GPU hybrid inference with all available VRAM, you can use the following instructions to run PowerInfer:
```bash
./build/bin/main -m /PATH/TO/MODEL -n $output_token_count -t $thread_num -p $prompt
```
If you want to limit the VRAM usage of GPU:
```bash
./build/bin/main -m /PATH/TO/MODEL -n $output_token_count -t $thread_num -p $prompt --vram-budget $vram_gb
./build/bin/main -m /PATH/TO/MODEL -n $output_token_count -t $thread_num -p $prompt
# ./build/bin/main -m ./ReluFalcon-40B-PowerInfer-GGUF/falcon-40b-relu.q4.powerinfer.gguf -n 128 -t 8 -p "Once upon a time"
```
As for now, it requires an offline-generated "GPU index" file to split FFNs on GPU. And we found these files are hard to maintain and distribute. We will ship automatic FFN split based on VRAM capacity via [#11](https://github.com/SJTU-IPADS/PowerInfer/pull/11) very soon.
If you want to limit the VRAM usage of GPU:
```bash
./build/bin/main -m /PATH/TO/MODEL -n $output_token_count -t $thread_num -p $prompt --vram-budget $vram_gb
# ./build/bin/main -m ./ReluLLaMA-7B-PowerInfer-GGUF/llama-7b-relu.powerinfer.gguf -n 128 -t 8 -p "Once upon a time" --vram-budget 8
```
Under CPU-GPU hybrid inference, PowerInfer will automatically offload all dense activation blocks to GPU and split FFN on GPU if possible.
## Evaluation
@ -119,6 +123,13 @@ As for now, it requires an offline-generated "GPU index" file to split FFNs on G
PowerInfer achieves up to 11x and 8x speedup for FP16 and INT4 models!
## FAQs
1. What if I encounter `CUDA_ERROR_OUT_OF_MEMORY`?
- You can try running with the `--reset-gpu-index` argument to rebuild the GPU index for this model and avoid any stale cache.
- Due to our current implementation, model offloading might not be as accurate as expected. You can try `--vram-budget` with a slightly lower value, or `--disable-gpu-index` to disable FFN offloading.
2. What if...
- Issues are welcome! Please feel free to open an issue and attach your running environment and running parameters. We will try our best to help you.
## TODOs
We will release the code and data in the following order, please stay tuned!
@ -130,7 +141,7 @@ We will release the code and data in the following order, please stay tuned!
- [ ] Support Metal for Mac
- [ ] Release code for OPT models
- [ ] Release predictor training code
- [ ] Support online split for FFN network
- [x] Support online split for FFN network
- [ ] Support Multi-GPU
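
The run commands in the README hunk above can also be assembled programmatically. The sketch below is illustrative only: `build_main_command` is a hypothetical helper that is not part of PowerInfer, while the binary path and the flags (`-m`, `-n`, `-t`, `-p`, `--vram-budget`, `--reset-gpu-index`, `--disable-gpu-index`) are the ones that appear in this commit.

```python
import shlex
from typing import List, Optional


def build_main_command(
    model_path: str,
    prompt: str,
    n_tokens: int = 128,
    n_threads: int = 8,
    vram_budget_gb: Optional[float] = None,
    reset_gpu_index: bool = False,
    disable_gpu_index: bool = False,
    binary: str = "./build/bin/main",
) -> List[str]:
    """Hypothetical helper: assemble the argv for the PowerInfer `main` binary."""
    cmd = [binary, "-m", model_path, "-n", str(n_tokens),
           "-t", str(n_threads), "-p", prompt]
    if vram_budget_gb is not None:
        # Same unit as the README example: gigabytes.
        cmd += ["--vram-budget", str(vram_budget_gb)]
    if reset_gpu_index:
        # Rebuild the cached GPU index instead of reusing it.
        cmd.append("--reset-gpu-index")
    if disable_gpu_index:
        # Skip GPU index loading and FFN splitting entirely.
        cmd.append("--disable-gpu-index")
    return cmd


if __name__ == "__main__":
    argv = build_main_command("/PATH/TO/MODEL", "Once upon a time", vram_budget_gb=8)
    # shlex.join quotes the prompt so the line can be pasted into a shell.
    print(shlex.join(argv))
```

Returning a list rather than a single string keeps the prompt intact when the result is passed to `subprocess.run` without a shell.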

View file

@ -471,12 +471,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.lora_base = argv[i];
} else if (arg == "--gpu-index") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.gpu_index = argv[i];
} else if (arg == "--reset-gpu-index") {
params.reset_gpu_index = true;
} else if (arg == "--disable-gpu-index") {
params.disale_gpu_index = true;
} else if (arg == "--mmproj") {
if (++i >= argc) {
invalid_param = true;
@ -910,6 +908,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.reset_gpu_index = params.reset_gpu_index;
mparams.disable_gpu_index = params.disale_gpu_index;
return mparams;
}
@ -968,24 +968,6 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
return std::make_tuple(nullptr, nullptr);
}
if (llama_use_sparse_inference(model)) {
fprintf(stderr, "%s: postprocessing PowerInfer model '%s'\n", __func__, params.model.c_str());
if (!params.gpu_index.empty()) {
int err = llama_model_apply_gpu_idx_from_file(model, params.gpu_index.c_str(), true);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply mlp adapter\n", __func__);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
}
if (llama_model_apply_augmentation(model) != 0) {
fprintf(stderr, "%s: error: failed to apply augmentation\n", __func__);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
}
auto cparams = llama_context_params_from_gpt_params(params);
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
@ -1357,7 +1339,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "gpu_index: %s\n", params.gpu_index.c_str());
fprintf(stream, "reset_gpu_index: %s\n", params.reset_gpu_index ? "true" : "false");
fprintf(stream, "disable_gpu_index: %s\n", params.disale_gpu_index? "true": "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);

View file

@ -91,7 +91,8 @@ struct gpt_params {
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
std::string gpu_index = ""; // sparse activation mlp adapter path
bool reset_gpu_index = false; // refresh the gpu index file
bool disale_gpu_index = false; // disable loading gpu index and splitting ffn
int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line

View file

@ -48,12 +48,11 @@ int main(int argc, char ** argv) {
params.n_threads = std::atoi(argv[6]);
}
if (argc >= 8) {
params.gpu_index = argv[7];
}
// For testing purposes, we always reset the GPU index
params.reset_gpu_index = true;
printf("params: model = %s, prompt = %s, n_parallel = %d, n_len = %d, n_gpu_layers = %d, n_threads = %d, gpu_index = %s\n",
params.model.c_str(), params.prompt.c_str(), n_parallel, n_len, n_gpu_layers, params.n_threads, params.gpu_index.c_str());
printf("params: model = %s, prompt = %s, n_parallel = %d, n_len = %d, n_gpu_layers = %d, n_threads = %d, reset_gpu_index = true\n",
params.model.c_str(), params.prompt.c_str(), n_parallel, n_len, n_gpu_layers, params.n_threads);
if (params.prompt.empty()) {
params.prompt = "Hello my name is";
@ -76,21 +75,6 @@ int main(int argc, char ** argv) {
return 1;
}
if (!params.gpu_index.empty()) {
int err = llama_model_apply_gpu_idx_from_file(model, params.gpu_index.c_str(), true);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply mlp adapter\n", __func__);
llama_free_model(model);
return 1;
}
}
if (llama_model_apply_augmentation(model) != 0) {
fprintf(stderr, "%s: error: failed to apply model augmentation\n", __func__);
llama_free_model(model);
return 1;
}
// tokenize the prompt
std::vector<llama_token> tokens_list;

8
ggml.c
View file

@ -17497,7 +17497,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
}
const int n_threads = cplan->n_threads;
#ifdef LLAMA_CUBLAS
#ifdef GGML_USE_CUBLAS
struct ggml_compute_state_shared state_shared = {
/*.cgraph =*/ cgraph,
/*.cgraph_plan =*/ cplan,
@ -17534,7 +17534,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
.ith = j,
.shared = &state_shared,
};
#ifdef LLAMA_CUBLAS
#ifdef GGML_USE_CUBLAS
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread_hybrid, &workers[j]);
#else
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
@ -17551,7 +17551,8 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
const int64_t perf_start_time_us = ggml_perf_time_us();
// this is a work thread too
#ifdef LLAMA_CUBLAS
#ifdef GGML_USE_CUBLAS
int compute_status = (size_t) ggml_graph_compute_thread_hybrid(&workers[0]);
#else
int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
@ -19590,7 +19591,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
sparse_deriv = GGML_DENSE_INFERENCE;
} else if (strncmp(magic, GGUF_POWERINFER_MAGIC, sizeof(magic)) == 0) {
sparse_deriv = GGML_SPARSE_INFERENCE;
fprintf(stderr, "%s: PowerInfer derived model detected. Sparse inference will be used.\n", __func__);
} else {
fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
fclose(file);

View file

@ -74,6 +74,9 @@ class Keys:
class PowerInfer:
SPARSE_THRESHOLD = "powerinfer.sparse_threshold"
class Split:
VRAM_CAPACITY = "split.vram_capacity"
#
# recommended mapping of model tensor names for storage in gguf
@ -385,6 +388,9 @@ class GGMLQuantizationType(IntEnum):
Q5_K = 13
Q6_K = 14
Q8_K = 15
I8 = 16
I16 = 17
I32 = 18
class GGUFEndian(IntEnum):

323
llama.cpp
View file

@ -61,6 +61,7 @@
#include <cstdio>
#include <cstring>
#include <ctime>
#include <libgen.h>
#include <forward_list>
#include <fstream>
#include <functional>
@ -216,6 +217,8 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_UNKNOWN, "unknown" },
};
enum llm_kv {
@ -266,6 +269,8 @@ enum llm_kv {
LLM_KV_TOKENIZER_RWKV,
LLM_KV_SPARSE_THRESHOLD,
LLM_KV_SPLIT_VRAM_CAPACITY,
};
static std::map<llm_kv, std::string> LLM_KV_NAMES = {
@ -316,6 +321,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_SPARSE_THRESHOLD, "powerinfer.sparse_threshold" },
{ LLM_KV_SPLIT_VRAM_CAPACITY, "split.vram_capacity" },
};
struct LLM_KV {
@ -756,9 +763,10 @@ struct llama_buffer {
struct llama_file {
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
std::string fname;
size_t size;
llama_file(const char * fname, const char * mode) {
llama_file(const char * fname, const char * mode): fname(fname) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
@ -1367,7 +1375,7 @@ struct llama_vocab {
}
};
struct llama_mlp_model_loader;
struct llama_gpu_split_loader;
struct llama_augmentation_model_loader;
struct llama_model {
@ -1405,7 +1413,7 @@ struct llama_model {
std::unique_ptr<llama_mmap> mapping;
// aux model loaders for dynamically loaded/transformed model weights
std::unique_ptr<struct llama_mlp_model_loader> mlp_model_loader;
std::unique_ptr<struct llama_gpu_split_loader> mlp_model_loader;
std::unique_ptr<struct llama_augmentation_model_loader> aug_model_loader;
// objects representing data potentially being locked in memory
@ -2632,30 +2640,28 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
}
struct llama_mlp_model_loader {
struct llama_gpu_split_loader {
int n_tensors = 0;
size_t n_bytes = 0; // tensor data bytes
const std::string fname;
llama_file file;
int fver;
bool use_mmap = false; // only supports mmap yet
std::unique_ptr<llama_mmap> mapping;
struct ggml_context * ctx_meta = nullptr;
llama_mlp_model_loader(const std::string & fname, bool use_mmap) : fname(fname), use_mmap(use_mmap), file(fname.c_str(), "rb") {
llama_model_loader * idx_loader;
size_t vram_required = 0;
llama_gpu_split_loader(const std::string & fname, bool use_mmap) : fname(fname), use_mmap(use_mmap) {
GGML_ASSERT(use_mmap);
// verify magic and version
uint32_t magic = file.read_u32();
// TODO: assert on file magic once we have a stable format
GGML_ASSERT(magic == 0xDEADBEEF && "invalid file magic" || true);
idx_loader = new llama_model_loader(fname, use_mmap);
GGUF_GET_KEY(idx_loader->ctx_gguf, vram_required, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_NAMES[LLM_KV_SPLIT_VRAM_CAPACITY]);
printf("loaded gpu_idx, vram_required: %zu\n", vram_required);
fver = file.read_u32();
GGML_ASSERT(fver == 1 && "unsupported file version");
n_tensors = file.read_u32();
n_tensors = idx_loader->n_tensors;
// allocate metadata/data for mlp tensors
// TODO: support allocating buffer for tensor data (when mmap is not used)
@ -2667,138 +2673,43 @@ struct llama_mlp_model_loader {
/*.no_alloc =*/ true,
};
ctx_meta = ggml_init(params);
}
// memory-map the mlp weights file
mapping.reset(new llama_mmap(&file, /* prefetch */ 0, ggml_is_numa()));
bool check_vram_allocable(size_t vram_budget) {
return vram_budget >= vram_required;
}
int apply_tensors_to_base_model(llama_model * model) {
int n_layers = model->layers.size();
// TODO: assert fp is at the end of headers
if (n_tensors != model -> layers.size() * 2) {
LLAMA_LOG_ERROR("%s: error: the number of mlp adapters does not match the layer of model\n", __func__);
if (n_tensors != n_layers * 2) {
LLAMA_LOG_ERROR("%s: error: the number of gpu splits does not match the layer of model\n", __func__);
return 1;
}
LLAMA_LOG_INFO("%s: applying gpu_idx adapter from '%s' - please wait ...\n", __func__, fname.c_str());
const int64_t t_start_mlp_us = ggml_time_us();
for (llama_layer &model_layer : model -> layers) {
ggml_tensor *mlp_fc1_tensor = load_mlp_tensor_from_stream();
ggml_tensor *mlp_fc2_tensor = load_mlp_tensor_from_stream();
#ifdef GGML_USE_CUBLAS
// ggml_set_backend(mlp_fc1_tensor, GGML_BACKEND_GPU);
// ggml_cuda_transform_tensor(mlp_fc1_tensor->data, mlp_fc1_tensor);
// gpu bucket to GPU
ggml_set_backend(mlp_fc2_tensor, GGML_BACKEND_GPU);
ggml_cuda_transform_tensor(mlp_fc2_tensor->data, mlp_fc2_tensor);
#endif // GGML_USE_CUBLAS
if (mlp_fc1_tensor == nullptr || mlp_fc2_tensor == nullptr) {
LLAMA_LOG_ERROR("%s: error: failed to load mlp tensors\n", __func__);
for (int il = 0; il < n_layers; il++) {
llama_layer &model_layer = model->layers[il];
ggml_tensor * gpu_idx = idx_loader->get_tensor_meta(il*2);
ggml_tensor * gpu_bucket = idx_loader->get_tensor_meta(il*2+1);
if (gpu_idx == nullptr || gpu_bucket == nullptr) {
LLAMA_LOG_ERROR("%s: error: failed to load gpu index or bucket\n", __func__);
return 1;
}
// load model layer and check dimensions
// ggml_tensor *model_up_t = model_layer.ffn_up;
// GGML_ASSERT(model_up_t != nullptr);
// if (model_up_t->ne[0] != mlp_fc1_tensor->ne[0] ||
// model_up_t->ne[1] != mlp_fc2_tensor->ne[1]) {
// LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64
// " and %" PRId64
// ");"
// " are you sure that this adapter is for this model?\n",
// __func__, model_up_t->ne[0], mlp_fc1_tensor->ne[1]);
// return 1;
// }
// GGML_ASSERT(model_layer.mlp_pre_w1 == nullptr && model_layer.mlp_pre_w2 == nullptr);
model_layer.gpu_idx = mlp_fc1_tensor;
model_layer.gpu_bucket = mlp_fc2_tensor;
int *data1 = (int *)mlp_fc1_tensor->data;
int *data2 = (int *)mlp_fc2_tensor->data;
LLAMA_LOG_INFO(".");
model_layer.gpu_idx = idx_loader->create_tensor_for(ctx_meta, gpu_idx, GGML_BACKEND_CPU);
model_layer.gpu_bucket = idx_loader->create_tensor_for(ctx_meta, gpu_bucket, GGML_BACKEND_GPU);
}
llama_progress_callback cb = [](float progress, void *ctx) {
LLAMA_LOG_INFO(".");
};
idx_loader->load_all_data(ctx_meta, cb, nullptr, nullptr);
const int64_t t_mlp_us = ggml_time_us() - t_start_mlp_us;
LLAMA_LOG_INFO(" done (%.2f ms)\n", t_mlp_us / 1000.0);
return 0;
}
// Consumes the stream and returns a new mlp tensor.
// Returns nullptr on error.
// TODO: mmap mlp model file
ggml_tensor *load_mlp_tensor_from_stream() {
uint32_t n_dims = file.read_u32();
uint32_t name_length = file.read_u32();
uint32_t ftype = file.read_u32();
uint32_t ne[2] = {1, 1};
for (int i = 0; i < n_dims; ++i) {
ne[i] = file.read_u32();
}
std::string tensor_name;
{
char buf[1024];
file.read_raw(buf, name_length);
tensor_name = std::string(buf, name_length);
}
// const std::string mlp_suffix = ".mlp";
// size_t pos = tensor_name.rfind(mlp_suffix);
// if (pos == std::string::npos) {
// LLAMA_LOG_ERROR("%s: error: '%s' is not a mlp tensor\n", __func__,
// tensor_name.c_str());
// return nullptr;
// }
// std::string mlp_type = tensor_name.substr(pos + mlp_suffix.length());
// std::string base_name = tensor_name;
// base_name.erase(pos);
// LLAMA_LOG_INFO("%s: %s => %s (mlp type %s) (", __func__, tensor_name.c_str(),
// base_name.c_str(), mlp_type.c_str());
// for (int i = 0; i < n_dims; ++i) {
// LLAMA_LOG_INFO("%d ", ne[i]);
// }
// LLAMA_LOG_INFO(")\n");
// LLAMA_LOG_INFO("tensor name %s\n", tensor_name.c_str());
// create ggml tensor
ggml_type wtype;
switch (ftype) {
case 0:
wtype = GGML_TYPE_F32;
break;
case 1:
wtype = GGML_TYPE_F16;
break;
case 18:
wtype = GGML_TYPE_I32;
break;
default: {
LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n", __func__, ftype);
return nullptr;
}
}
ggml_tensor *mlp_tensor;
// if (n_dims != 2) {
// LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
// return nullptr;
// }
mlp_tensor = ggml_new_tensor_2d(ctx_meta, wtype, ne[0], ne[1]);
// ggml_set_name(mlp_tensor, "");
// load tensor data
size_t offset = file.tell();
size_t tensor_data_size = ggml_nbytes(mlp_tensor);
offset = (offset + 31) & -32;
file.seek(offset, SEEK_SET);
// point to the mmaped mlp model file
mlp_tensor -> data = (void *) (static_cast<char *>(mapping -> addr) + offset);
file.seek(tensor_data_size, SEEK_CUR);
return mlp_tensor;
}
};
// to dynamically load/transform llama model weights
@ -2815,8 +2726,8 @@ struct llama_augmentation_model_loader {
// const int64_t ggml_aux_tensor_size = 4 * (100 * 100 + 5120*40*4 * ggml_tensor_overhead() + (int64_t)13824*5120*40*4);
int model_layer = model->layers.size();
int ffn_dim = model->layers[0].ffn_up->ne[1];
const size_t ggml_aux_tensor_size = 4 * (100 * 100 + model_layer*ffn_dim*sizeof(float) * ggml_tensor_overhead() );
printf("augmentation buffer: %ld\n", ggml_aux_tensor_size);
const size_t ggml_aux_tensor_size = 4 * (model_layer*ffn_dim*sizeof(float)*2+ model_layer*ffn_dim*sizeof(float) * ggml_tensor_overhead() );
struct ggml_init_params params = {
/*.mem_size =*/ ggml_aux_tensor_size,
/*.mem_buffer =*/ nullptr,
@ -2868,37 +2779,29 @@ struct llama_augmentation_model_loader {
#endif
}
void slice_ffn_mat_to_gpu(llama_layer & layer) {
size_t slice_ffn_mat_to_gpu(llama_layer & layer) {
std::vector<uint8_t> work_buffer;
ggml_cgraph * tmp_sum_gf = ggml_new_graph(aux_ctx);
ggml_tensor * gpu_idx = layer.gpu_idx;
// calculate the size of tensor to be copied
ggml_tensor * sum_t = ggml_sum(aux_ctx, gpu_idx);
ggml_build_forward_expand(tmp_sum_gf, sum_t);
ggml_graph_compute_helper(work_buffer, tmp_sum_gf, 2);
int64_t gpu_rows = *ggml_get_data_i32(sum_t);
int64_t gpu_index_len = gpu_idx->ne[0];
// ggml_tensor * gpu_bucket = ggml_new_tensor_1d(aux_ctx, GGML_TYPE_I32, gpu_rows);
// make bucket a reverse index back to unstriped mat
// int32_t * pbucket_data = (int32_t *)gpu_bucket->data;
// for (int i = 0; i < gpu_index_len; i++) {
// if (ggml_get_data_i32(gpu_idx)[i] == 0) {
// continue;
// }
// *pbucket_data = i;
// ++pbucket_data;
// }
// layer.gpu_bucket = gpu_bucket;
ggml_tensor *gpu_bucket = layer.gpu_bucket;
size_t offloaded_bytes = 0;
layer.ffn_gate_gpu = create_striped_mat_to_gpu(layer.ffn_gate, gpu_bucket);
layer.ffn_up_gpu = create_striped_mat_to_gpu(layer.ffn_up, gpu_bucket);
layer.ffn_down_gpu = create_striped_mat_to_gpu(layer.ffn_down_t, gpu_bucket);
if (layer.ffn_gate_gpu) {
offloaded_bytes += ggml_nbytes(layer.ffn_gate_gpu);
}
if (layer.ffn_up_gpu) {
offloaded_bytes += ggml_nbytes(layer.ffn_up_gpu);
}
if (layer.ffn_down_gpu) {
offloaded_bytes += ggml_nbytes(layer.ffn_down_gpu);
}
return offloaded_bytes;
}
int apply_augmentation_to_base_model(llama_model * model) {
size_t offload_ffn_split(llama_model * model) {
LLAMA_LOG_INFO("%s: applying augmentation to model - please wait ...\n", __func__);
const int64_t t_start_aug_us = ggml_time_us();
std::vector<uint8_t> work_buffer;
@ -2910,6 +2813,7 @@ struct llama_augmentation_model_loader {
#endif
// load gpu_idx and slice mat to gpu
size_t offloaded_bytes = 0;
for (llama_layer &model_layer : model -> layers) {
// gpu_idx load
if (model_layer.gpu_idx == NULL && model_layer.gpu_bucket == NULL) {
@ -2919,12 +2823,12 @@ struct llama_augmentation_model_loader {
ggml_tensor * gpu_bucket = ggml_new_tensor_1d(aux_ctx, GGML_TYPE_I32, 0);
model_layer.gpu_bucket = gpu_bucket;
}
slice_ffn_mat_to_gpu(model_layer);
offloaded_bytes += slice_ffn_mat_to_gpu(model_layer);
LLAMA_LOG_INFO(".");
}
LLAMA_LOG_INFO(" done (%.2f ms)\n", (ggml_time_us() - t_start_aug_us) / 1000.0);
return 0;
return offloaded_bytes;
}
};
@ -2957,7 +2861,7 @@ struct buffered_tensor_allocator {
// For GPU tensors, we need to allocate them in VRAM as much as possible,
// and update the tensor data in-place. If the VRAM budget is exceeded,
// we allocate the tensor in CPU memory.
void flush() {
size_t flush() {
#if defined(GGML_USE_CUBLAS)
// iterate over offloading priorities
for (int enum_i = TENSOR_OFFLOAD_ATTN; enum_i <= TENSOR_OFFLOAD_KV_CACHE; enum_i ++) {
@ -2965,7 +2869,7 @@ struct buffered_tensor_allocator {
for (ggml_tensor * meta_tensor : alloc_queues[level]) {
size_t tensor_data_size = ggml_nbytes(meta_tensor);
if (vram_allocated_bytes + tensor_data_size > vram_budget_bytes) {
return;
return vram_allocated_bytes;
}
// allocate in VRAM
ggml_set_backend(meta_tensor, GGML_BACKEND_GPU);
@ -2974,15 +2878,83 @@ struct buffered_tensor_allocator {
}
ml.done_getting_tensors();
#endif
return vram_allocated_bytes;
}
};
static bool load_gpu_split_from_split_file(llama_model & model, std::string split_path, size_t vram_budget) {
llama_gpu_split_loader loader(split_path, true);
return loader.check_vram_allocable(vram_budget)
&& loader.apply_tensors_to_base_model(&model) == 0;
}
static bool llm_load_gpu_split_with_budget(llama_model_loader & ml, llama_model & model, size_t vram_allocatable_bytes, bool no_cache) {
const char * model_path = ml.file.fname.c_str();
std::string cached_split_path = std::string(model_path) + ".generated.gpuidx";
std::string model_path_copy(model_path); // dirname() may modify its argument, so work on a copy
const char * model_basedir = dirname(&model_path_copy[0]);
// Load GPU split from previously generated cache
if (access(cached_split_path.c_str(), F_OK) == 0 && !no_cache) {
if (load_gpu_split_from_split_file(model, cached_split_path, vram_allocatable_bytes)) {
return true;
}
LLAMA_LOG_ERROR("%s: error: failed to apply previously generated gpu split from '%s'\n", __func__, cached_split_path.c_str());
}
// Generate GPU split
std::string activation_path = std::string(model_basedir) + "/activation";
if (access(activation_path.c_str(), F_OK) != 0) {
LLAMA_LOG_ERROR("%s: error: activation files under '%s' not found\n", __func__, activation_path.c_str());
return false;
}
// Calculate solver parameters
ggml_tensor * ffn_up = model.layers[0].ffn_up;
ggml_tensor * ffn_gate = model.layers[0].ffn_gate;
int slice_size = ffn_up->ne[1] * ggml_type_size(ffn_up->type) / ggml_blck_size(ffn_up->type);
// For model arch with FFN gate, the gate is also sliced, otherwise only the up and down matrices are sliced
int vram_bytes_per_slice = slice_size * (ffn_gate ? 4.5 : 2); // TODO: why 4.5, not 3?
int neuron_cap = floor((double)vram_allocatable_bytes / vram_bytes_per_slice) * 4;
LLAMA_LOG_INFO("invoking powerinfer Python module to generate gpu split for %.2f MiB of VRAM\n", vram_allocatable_bytes / 1024.0 / 1024.0);
std::stringstream command_ss;
command_ss << "python3 -m powerinfer"
<< " --activation " << activation_path
<< " --layer " << model.hparams.n_layer
<< " --neuron " << ffn_up->ne[1]
<< " --capacity " << neuron_cap
<< " --vram-capacity " << vram_allocatable_bytes
<< " --output " << cached_split_path;
if (system(command_ss.str().c_str()) != 0 || access(cached_split_path.c_str(), F_OK) != 0) {
LLAMA_LOG_ERROR("%s: error: failed to generate gpu split\n", __func__);
return false;
}
return load_gpu_split_from_split_file(model, cached_split_path, vram_allocatable_bytes);
}
static void llm_load_gpu_split(llama_model_loader & ml, llama_model & model, size_t vram_budget_bytes, bool no_cache, bool no_offload) {
#if defined(GGML_USE_CUBLAS)
if (vram_budget_bytes >= 512ull * 1024 * 1024 && !no_offload) {
vram_budget_bytes -= 512ull * 1024 * 1024; // leave 512 MiB as a safety margin
if (!llm_load_gpu_split_with_budget(ml, model, vram_budget_bytes, no_cache)) {
LLAMA_LOG_ERROR("%s: error: failed to generate gpu split, an empty one will be used\n", __func__);
}
}
#endif
// Apply GPU index and split FFNs to GPU
size_t ffn_offloaded_bytes = llama_model_offload_ffn_split(&model);
LLAMA_LOG_INFO("%s: offloaded %.2f MiB of FFN weights to GPU\n", __func__, ffn_offloaded_bytes / 1024.0 / 1024.0);
}
static void llm_load_sparse_model_tensors(
llama_model_loader & ml,
llama_model & model,
int main_gpu,
long int vram_budget_bytes,
const float * tensor_split,
bool reset_gpu_index,
bool disable_ffn_split,
bool use_mlock,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
@ -3131,19 +3103,20 @@ static void llm_load_sparse_model_tensors(
}
}
alloc.flush();
size_t vram_allocated_bytes = alloc.flush();
GGML_ASSERT(vram_allocated_bytes < vram_capacity);
// print memory requirements
{
// this is the total memory required to run the inference
size_t mem_required =
ctx_size +
mmapped_size - alloc.vram_allocated_bytes; // weights in VRAM not in memory
mmapped_size - vram_allocated_bytes; // weights in VRAM not in memory
LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, alloc.vram_allocated_bytes / 1024.0 / 1024.0);
LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_allocated_bytes / 1024.0 / 1024.0);
#endif
}
@ -3161,11 +3134,14 @@ static void llm_load_sparse_model_tensors(
model.mapping = std::move(ml.mapping);
// Offload FFN segments to GPU if possible
llm_load_gpu_split(ml, model, vram_capacity - vram_allocated_bytes, reset_gpu_index, disable_ffn_split);
// loading time will be recalculated after the first eval, so
// we take page faults deferred by mmap() into consideration
model.t_load_us = ggml_time_us() - model.t_start_us;
model.n_gpu_layers = -1; // based on offloading results?
model.n_gpu_layers = -1; // TODO: based on offloading results, by category?
}
static void llm_load_tensors(
@ -3893,6 +3869,10 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con
try {
llama_model_loader ml(fname, params.use_mmap);
if (ml.sparse_deriv == GGML_SPARSE_INFERENCE) {
LLAMA_LOG_INFO("%s: PowerInfer model loaded. Sparse inference will be used.\n", __func__);
}
model.hparams.vocab_only = params.vocab_only;
model.sparse_deriv = ml.sparse_deriv;
@ -3918,8 +3898,8 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con
}
double vram_budget_bytes = params.vram_budget_gb * 1024.0 * 1024.0 * 1024.0;
llm_load_sparse_model_tensors(
ml, model, params.main_gpu, vram_budget_bytes, params.tensor_split, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
ml, model, params.main_gpu, vram_budget_bytes, params.reset_gpu_index, params.disable_gpu_index,
params.use_mlock, params.progress_callback, params.progress_callback_user_data
);
} else {
llm_load_tensors(
@ -9671,24 +9651,19 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
}
int llama_model_apply_gpu_idx_from_file(struct llama_model * model, const char * path_mlp, bool use_mmap) {
llama_mlp_model_loader * mlp_ml = new llama_mlp_model_loader(path_mlp, use_mmap);
llama_gpu_split_loader * mlp_ml = new llama_gpu_split_loader(path_mlp, use_mmap);
if (mlp_ml -> apply_tensors_to_base_model(model) > 0) {
LLAMA_LOG_ERROR("%s: failed to apply mlp adapter\n", __func__);
LLAMA_LOG_ERROR("%s: failed to apply gpu split\n", __func__);
return 1;
}
model -> mlp_model_loader = std::unique_ptr<llama_mlp_model_loader>(mlp_ml);
model -> mlp_model_loader = std::unique_ptr<llama_gpu_split_loader>(mlp_ml);
return 0;
}
// Apply postprocessing steps for PowerInfer derived models
int llama_model_apply_augmentation(struct llama_model * model) {
size_t llama_model_offload_ffn_split(struct llama_model * model) {
llama_augmentation_model_loader * aug_ml = new llama_augmentation_model_loader(model);
if (aug_ml -> apply_augmentation_to_base_model(model) > 0) {
LLAMA_LOG_ERROR("%s: failed to apply augmentation adapter\n", __func__);
return 1;
}
model -> aug_model_loader = std::unique_ptr<llama_augmentation_model_loader>(aug_ml);
return 0;
size_t offloaded_bytes = aug_ml->offload_ffn_split(model);
return offloaded_bytes;
}
int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
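
The solver invocation assembled in `llm_load_gpu_split_with_budget` can be mirrored in a few lines of Python, which makes the arithmetic easier to check by hand. This is a sketch of the logic in this diff, not code from the repository: `plan_gpu_split` and its parameter names are hypothetical, the element size and block size must be supplied by the caller (the C++ code reads them from `ggml_type_size` and `ggml_blck_size`), and the 4.5 and `* 4` factors are copied from the source, where they are flagged as temporary.

```python
import os
from typing import List, Tuple


def plan_gpu_split(
    model_path: str,
    n_layer: int,
    ffn_up_ne1: int,
    type_size: int,
    blck_size: int,
    has_ffn_gate: bool,
    vram_allocatable_bytes: int,
) -> Tuple[int, List[str]]:
    """Hypothetical mirror of the C++ logic: return (neuron_cap, solver argv)."""
    # The cached split lives next to the model file.
    cached_split_path = model_path + ".generated.gpuidx"
    # Activation statistics are expected in an `activation` directory
    # beside the model; a bare file name resolves to "./activation".
    activation_path = os.path.join(os.path.dirname(model_path) or ".", "activation")

    # Integer division, as in the C++ expression.
    slice_size = ffn_up_ne1 * type_size // blck_size
    # 4.5 (with FFN gate) and 2 (without) are the factors used in the source;
    # the result is truncated to an integer like the C++ `int`.
    vram_bytes_per_slice = int(slice_size * (4.5 if has_ffn_gate else 2))
    # The trailing * 4 is the "temporary fix for neuron factor" from the commit.
    neuron_cap = (vram_allocatable_bytes // vram_bytes_per_slice) * 4

    argv = [
        "python3", "-m", "powerinfer",
        "--activation", activation_path,
        "--layer", str(n_layer),
        "--neuron", str(ffn_up_ne1),
        "--capacity", str(neuron_cap),
        "--vram-capacity", str(vram_allocatable_bytes),
        "--output", cached_split_path,
    ]
    return neuron_cap, argv


if __name__ == "__main__":
    # Illustrative numbers only; 512 MiB is subtracted first, as in llm_load_gpu_split.
    budget = 8 * 1024**3 - 512 * 1024**2
    cap, argv = plan_gpu_split("/models/m.gguf", 32, 11008, 2, 1, True, budget)
    print(cap, " ".join(argv))
```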

View file

@ -173,6 +173,8 @@ extern "C" {
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool reset_gpu_index; // force reset of the GPU index
bool disable_gpu_index; // bypass the GPU index and FFN split
};
struct llama_context_params {
@ -347,7 +349,7 @@ extern "C" {
const char * path_mlp,
bool use_mmap);
LLAMA_API int llama_model_apply_augmentation(struct llama_model * model);
LLAMA_API size_t llama_model_offload_ffn_split(struct llama_model * model);
//
// KV cache

View file

View file

@ -0,0 +1,43 @@
import argparse

from .solver import solve_gpu_split
from .export_split import export_split

if __name__ == "__main__":
    # Set up command line arguments
    parser = argparse.ArgumentParser(description='Optimize neuron activation based on VRAM capacity and other parameters.')
    parser.add_argument('--activation', type=str, required=True, help='Path to the directory containing activation data.')
    parser.add_argument('--neuron', type=int, default=8192*4, help='Total number of neurons in the network.')
    parser.add_argument('--capacity', type=int, default=int(8192*4*32*0.1), help='Total VRAM capacity for the model.')
    parser.add_argument('--layer', type=int, default=59, help='Total number of layers in the neural network.')
    parser.add_argument('--vram-capacity', type=int, help='Total VRAM capacity (Bytes) available for splitting')
    parser.add_argument('--batch', type=int, default=256, help='Batch size for processing.')
    parser.add_argument('--threshold', type=int, default=0, help='Threshold for splitting a layer across multiple GPUs.')
    parser.add_argument('--output', type=str, required=True, help='File path for the output GPU index file.')
    args = parser.parse_args()
    print("solver args:", args)
    solved = solve_gpu_split(
        activation_path=args.activation,
        neuron=args.neuron,
        capacity=args.capacity,
        layer=args.layer,
        batch=args.batch,
        threshold=args.threshold,
    )
    print(f"solved: {solved}, total neurons: {sum(solved)}")
    export_split(
        activations_path=args.activation,
        output_path=args.output,
        solved_list=solved,
        vram_capacity=args.vram_capacity
    )
    print(f"Exported to {args.output}")
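
Note: the entry point above takes all of its inputs as flags, so a caller can assemble the command line programmatically. The sketch below is an illustration only. It assumes the package is runnable as `python -m powerinfer` (inferred from the project name, not confirmed by this diff), and `build_gpu_index_command`, the paths, and the numeric values are hypothetical placeholders.

```python
import sys


def build_gpu_index_command(activation_dir, output_path, *, neuron, layer,
                            capacity, vram_capacity, batch=256, threshold=0):
    """Return an argv list for the GPU index generator (hypothetical helper)."""
    return [
        sys.executable, "-m", "powerinfer",  # assumed module name
        "--activation", str(activation_dir),
        "--neuron", str(neuron),
        "--capacity", str(capacity),
        "--layer", str(layer),
        "--vram-capacity", str(vram_capacity),
        "--batch", str(batch),
        "--threshold", str(threshold),
        "--output", str(output_path),
    ]


# Placeholder values for illustration; real values depend on the model and GPU.
cmd = build_gpu_index_command(
    "activation", "model.gpu_index.bin",
    neuron=32768, layer=59, capacity=104857, vram_capacity=8 * 1024**3,
)
print(" ".join(cmd[1:]))
```

The resulting list can be handed to `subprocess.run(cmd, check=True)`; building a list rather than a shell string avoids quoting problems with paths that contain spaces.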

View file

@ -0,0 +1,70 @@
import argparse
import pickle
import gguf
from gguf.constants import GGMLQuantizationType
from gguf.gguf_writer import GGUFWriter
import torch
from pathlib import Path
import os
import struct
import numpy as np

def load_activation_weights(models_base: Path):
    # TODO: might need a specification file to indicate which models to load.
    # But for now, let's assume it is a plain directory of activation_{0, ... , n_layers - 1}.pt
    *_, files = next(os.walk(models_base))
    return [torch.load(models_base / f"activation_{i}.pt") for i in range(len(files))]

def append_gpu_idx(gguf: GGUFWriter, i_layer: int, activation, select_count) -> None:
    _, indices = torch.topk(activation, k=int(select_count))
    gpu_idx = torch.zeros_like(activation)
    gpu_idx[indices] = 1
    gpu_idx = gpu_idx.numpy().astype(np.int32)
    key = f"blk.{i_layer}.gpu_idx"
    print(
        f"{key} => {key} {gpu_idx.shape} {gpu_idx.dtype} {gpu_idx.nbytes/1024/1024} MiB"
    )
    gguf.add_tensor(
        name=key,
        tensor=gpu_idx,
        raw_shape=gpu_idx.shape[::-1],
        raw_dtype=GGMLQuantizationType.I32,
    )
    indices = indices.numpy().astype(np.int32)
    gpu_bucket = np.sort(indices)
    key = f"blk.{i_layer}.gpu_bucket"
    print(
        f"{key} => {key} {gpu_bucket.shape} {gpu_bucket.dtype} {gpu_bucket.nbytes/1024/1024} MiB"
    )
    gguf.add_tensor(
        name=key,
        tensor=gpu_bucket,
        raw_shape=gpu_bucket.shape[::-1],
        raw_dtype=GGMLQuantizationType.I32,
    )
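
Note: `append_gpu_idx` emits two views of the same selection per layer: a 0/1 mask over all neurons (`gpu_idx`) and the sorted list of selected neuron indices (`gpu_bucket`). The dependency-free sketch below mirrors that logic with plain lists instead of `torch.topk`; the function name `topk_mask_and_bucket` is hypothetical, and tie-breaking may differ from what torch does.

```python
def topk_mask_and_bucket(activation, k):
    """Pick the k most frequently activated neurons of one layer.

    Returns (mask, bucket): mask[i] is 1 if neuron i is selected, and
    bucket holds the selected neuron indices in ascending order.
    """
    # Indices ordered by activation count, highest first (ties keep index order).
    order = sorted(range(len(activation)),
                   key=lambda i: activation[i], reverse=True)[:k]
    mask = [0] * len(activation)
    for i in order:
        mask[i] = 1
    bucket = sorted(order)
    return mask, bucket


mask, bucket = topk_mask_and_bucket([0.1, 0.9, 0.3, 0.7], k=2)
print(mask, bucket)
```

Storing both forms trades a little space for convenience: the mask answers "is neuron i on the GPU?" in constant time, while the bucket lists exactly which rows to gather.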

def export_split(activations_path: str, output_path: str, solved_list: list[int], vram_capacity: int):
    predictors = load_activation_weights(Path(activations_path)) # predictor => activation count
    gguf_out = GGUFWriter(output_path, "generic.gpu_index")
    for i, (activation, selected_count) in enumerate(zip(predictors, solved_list)):
        append_gpu_idx(gguf_out, i, activation, selected_count)
    # set kvs
    gguf_out.add_block_count(len(predictors))
    # TODO: better to save the actual capacity that split neurons require
    gguf_out.add_uint64(gguf.Keys.Split.VRAM_CAPACITY, vram_capacity)
    gguf_out.write_header_to_file()
    gguf_out.write_kv_data_to_file()
    gguf_out.write_tensors_to_file()
    gguf_out.close()
    # post-process: write another unique file header to distinguish from the original GGUF file
    with open(output_path, "r+b") as fout:
        POWERINFER_MAGIC = int.from_bytes(b"PWRI", "little")
        fout.write(struct.pack("<I", POWERINFER_MAGIC))
        fout.write(struct.pack("<I", 3))
    print(f"exported GPU index to {output_path}")
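
Note: the post-processing step in `export_split` overwrites the first 8 bytes of the file with the magic `PWRI` and the version `3`, both as little-endian 32-bit integers. A reader could tell such a file apart from a regular GGUF file along the lines of the sketch below; `read_gpu_index_header` is a hypothetical helper written for illustration and is not the loader used by the C++ code.

```python
import io
import struct

POWERINFER_MAGIC = int.from_bytes(b"PWRI", "little")


def read_gpu_index_header(stream):
    """Read the 8-byte header and return (is_powerinfer_gpu_index, version)."""
    raw = stream.read(8)
    if len(raw) < 8:
        raise ValueError("file too short to contain a header")
    magic, version = struct.unpack("<II", raw)
    return magic == POWERINFER_MAGIC, version


# In-memory stand-in for a generated GPU index file.
fake_file = io.BytesIO(struct.pack("<II", POWERINFER_MAGIC, 3) + b"payload")
print(read_gpu_index_header(fake_file))
```

Because the magic is packed with `"<I"` from a little-endian integer, the bytes on disk read `PWRI` in order, which makes the file easy to identify in a hex dump.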

View file

@ -0,0 +1,90 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
from cvxopt.glpk import ilp
import numpy as np
from cvxopt import matrix
import torch
import pickle

def solve_gpu_split(
    activation_path: str,
    neuron: int,
    capacity: int,
    layer: int,
    batch: int,
    threshold: int,
):
    # Processing activation data
    values = []
    for i in range(layer):
        # Load and sort activation data for each layer
        freq = torch.load(f"{activation_path}/activation_{i}.pt")
        freq, _ = torch.sort(freq, descending=True)
        freq = freq * -1.0
        freq = freq.view(-1, batch)
        freq = freq.sum(dim=1)
        freq = freq.tolist()
        values += freq
    # Padding zero values for additional constraints
    for i in range(layer):
        values += [0.0]
    c = np.array(values, dtype=float)
    c = matrix(c)
    # Setting capacity and neuron count per batch
    CAP = capacity
    CAP = int(CAP / batch)
    neuron = int(neuron / batch)
    coeff = []
    h = []
    # Constraint 1: Total neuron activation constraint
    lst = []
    for i in range(neuron * layer):
        lst.append(1)
    for i in range(layer):
        lst.append(0)
    coeff.append(lst)
    h.append(CAP)
    # Constraint 2: Threshold constraint for GPU split per layer
    for i in range(layer):
        lst = [0] * (neuron * layer + layer)
        for j in range(neuron):
            lst[i * neuron + j] = -1
        lst[neuron * layer + i] = int(threshold / batch)
        coeff.append(lst)
        h.append(0)
    # Constraint 3: Upper bound on neuron activations
    for i in range(layer):
        lst = [0] * (neuron * layer + layer)
        for j in range(neuron):
            lst[i * neuron + j] = 1
        lst[neuron * layer + i] = -1000000 # Big-M coefficient: layer i's selection is capped at 1000000 * its indicator variable
        coeff.append(lst)
        h.append(0)
    # Convert lists to matrix format for ILP solver
    coeff = np.array(coeff, dtype=float)
    G = matrix(coeff)
    h = np.array(h, dtype=float)
    h = matrix(h)
    # Define the set of integer and binary variables
    I = set(range(neuron * layer + layer))
    B = set()
    # Solving the ILP problem
    (status, x) = ilp(c, G, h, None, None, B, I, options={'tm_lim' : 30000}) # with 30s timeout
    print(f"ILP Status: {status}")
    ans = list(x)
    print(f"Total Activation Units: {sum(ans)}")
    aligned_lst = []
    for i in range(layer):
        # `list * batch` repeats the slice, so the sum equals batch * sum(slice)
        aligned_lst.append(sum(ans[i * neuron:i * neuron + neuron] * batch))
    return aligned_lst
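
Note: around the `ilp` call, `solve_gpu_split` does two bookkeeping steps. It turns each layer's activation counts into one negated cost per group of `batch` neurons, and it later converts the solver's per-group answer back into a neuron count per layer. The sketch below reproduces both steps with plain lists so they can be checked without torch or cvxopt; `batched_costs` and `per_layer_counts` are hypothetical names, and the ILP itself is not reproduced here.

```python
def batched_costs(freq, batch):
    """Sort counts in descending order, sum them in groups of `batch`, negate.

    Negation turns "maximize activation covered" into the minimization
    form that the ILP solver expects.
    """
    if len(freq) % batch != 0:
        raise ValueError("number of neurons must be a multiple of batch")
    ordered = sorted(freq, reverse=True)
    return [-sum(ordered[i:i + batch]) for i in range(0, len(ordered), batch)]


def per_layer_counts(x, layers, groups_per_layer, batch):
    """Convert per-group solver values into selected neurons per layer."""
    counts = []
    for i in range(layers):
        groups = x[i * groups_per_layer:(i + 1) * groups_per_layer]
        counts.append(int(round(sum(groups))) * batch)
    return counts


print(batched_costs([1, 4, 2, 3], batch=2))
print(per_layer_counts([1, 0, 1, 1], layers=2, groups_per_layer=2, batch=2))
```

Grouping by `batch` shrinks the number of decision variables by the same factor, which is what keeps the problem within reach of the 30-second solver time limit set above.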

View file

@ -0,0 +1,20 @@
[build-system]
requires = [
"flit_core >=3.2,<4",
]
build-backend = "flit_core.buildapi"
[project]
name = "powerinfer"
authors = [
{name = "Holden", email = "hodlenx@gmail.com"},
]
requires-python = ">=3.9"
classifiers = ["License :: OSI Approved :: MIT License"]
version="0.0.1"
description="powerinfer.py: Python helpers for PowerInfer LLM inference engine"
dependencies = [
"torch>=2",
"cvxopt==1.3.2"
]

View file

@ -1,3 +1,4 @@
numpy==1.24.4
sentencepiece==0.1.98
-e ./gguf-py
-e ./powerinfer-py

View file

@ -1,142 +0,0 @@
#!/usr/bin/env python3
import argparse
import torch
import torch.nn as tnn
from pathlib import Path
import os
import re
import struct
from typing import Any, BinaryIO
import numpy as np
import pickle

class ReluMLP(tnn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ReluMLP, self).__init__()
        self.fc1 = tnn.Linear(input_dim, hidden_dim, bias=False)
        self.relu = tnn.ReLU()
        self.fc2 = tnn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def _load_mlp_model(model_file: Path):
    model = torch.load(model_file)
    # hidden_size, input_size = model.get("fc1.weight").shape
    # output_size, _ = model.get("fc2.weight").shape
    # mlp = ReluMLP(input_size, hidden_size, output_size)
    # mlp.load_state_dict(model)
    return model

def load_mlp_predictors(models_base: Path):
    # TODO: might need a specification file to indicate which models to load.
    # But for now, let's assume it is a plain directory of models_{0, ... , n_layers - 1}.pt
    *_, files = next(os.walk(models_base))
    return [_load_mlp_model(models_base / f"activation_{i}.pt") for i in range(len(files))]

def write_file_header(fout: BinaryIO, n_tensors: int) -> None:
    fout.write(b"gglp"[::-1]) # magic (GGml mLP)
    fout.write(struct.pack("i", 1)) # file version
    # TODO: If we found we need more common parameters, we can add them here.
    fout.write(struct.pack("i", n_tensors))

def write_tensor_header(
    fout: BinaryIO, key: str, shape: tuple[int, ...], dtype: np.dtype
) -> None:
    _NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1, "int32": 18}
    bkey = key.encode("utf-8")
    fout.write(
        struct.pack("iii", len(shape), len(bkey), _NUMPY_TYPE_TO_FTYPE[dtype.name])
    )
    fout.write(struct.pack("i" * len(shape), *shape))
    fout.write(bkey)
    # Aligns to 32 bytes
    fout.seek((fout.tell() + 31) & -32)

# TODO: need to add more details in key name to indicate the network, layer number, etc.
def _translate_mlp_key(key: str) -> str:
    match = re.match(r"^(fc\d+).weight$", key)
    if not match or len(match.groups()) != 1:
        raise ValueError(f"Unexpected key: {key}")
    return f"{match.group(1)}.weight.mlp"

def append_mlp_model(fout: BinaryIO, model: ReluMLP) -> None:
    model_dict = model.state_dict()
    for k, v in model_dict.items():
        key = _translate_mlp_key(k)
        # torch.nn.Linear stores the weight matrix as (output_dim, input_dim), so does GGML.
        weights = v.half().detach().numpy()
        # GGML stores the weight matrix as (input_dim, output_dim)
        dims = weights.shape[::-1]
        print(
            f"{k} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
        )
        # TODO: add option to write in float32
        write_tensor_header(fout, key, dims, np.dtype("float16"))
        weights.tofile(fout)

def append_gpu_idx(fout: BinaryIO, activation, select_count) -> None:
    values, indices = torch.topk(activation, k=int(select_count))
    gpu_idx = torch.zeros_like(activation)
    gpu_idx[indices] = 1
    gpu_idx = gpu_idx.numpy().astype(np.int32)
    weights = gpu_idx
    dims = gpu_idx.shape[::-1]
    key = "gpu_idx"
    print(
        f"{key} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
    )
    write_tensor_header(fout, key, dims, np.dtype("int32"))
    weights.tofile(fout)
    indices = indices.numpy().astype(np.int32)
    weights = indices
    dims = weights.shape[::-1]
    key = "gpu_bucket"
    print(
        f"{key} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
    )
    write_tensor_header(fout, key, dims, np.dtype("int32"))
    weights = np.sort(weights)
    weights.tofile(fout)

def main(predictors_path: str, output_path: str, solver_path: str):
    predictors = load_mlp_predictors(Path(predictors_path)) # predictor => activation count
    n_tensors = len(predictors) * 2 # gpu_idx and gpu_bucket
    print(f"found {len(predictors)} MLP adapters with {n_tensors} tensors")
    with open(solver_path, "rb") as f:
        loaded_lst = pickle.load(f)
        # print(f"check solver {loaded_lst}")
    with open(output_path, "wb") as fout:
        fout.truncate()
        write_file_header(fout, n_tensors=n_tensors)
        for i, activation in enumerate(predictors):
            print(f"appending gpu idx layer-{i}")
            append_gpu_idx(fout, activation, loaded_lst[i])
            # append_gpu_idx(fout, activation, (32768*0.0))
    print(f"converted MLP adapters from {predictors_path} to {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("predictors_path", help="path to the MLP predictors")
    parser.add_argument(
        "output_path",
        help="path to the output GGML adapter",
        default="./gpu-index.bin",
    )
    parser.add_argument("solver", help="path to the solver")
    args = parser.parse_args()
    main(args.predictors_path, args.output_path, args.solver)

solver.py
View file

@ -1,106 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
from cvxopt.glpk import ilp
import numpy as np
from cvxopt import matrix
import torch
import pickle

# Set up command line arguments
parser = argparse.ArgumentParser(description='Optimize neuron activation based on VRAM capacity and other parameters.')
parser.add_argument('--activation_path', type=str, required=True, help='Path to the directory containing activation data.')
parser.add_argument('--neuron', type=int, default=8192*4, help='Total number of neurons in the network.')
parser.add_argument('--capacity', type=int, default=int(8192*4*32*0.1), help='Total VRAM capacity for the model.')
parser.add_argument('--layer', type=int, default=59, help='Total number of layers in the neural network.')
parser.add_argument('--batch', type=int, default=32, help='Batch size for processing.')
parser.add_argument('--threshold', type=int, default=512, help='Threshold for splitting a layer across multiple GPUs.')
parser.add_argument('--output', type=str, required=True, help='File path for the output pickle file.')
args = parser.parse_args()

# Assigning command line arguments to variables
activation_path = args.activation_path
neuron = args.neuron
layer = args.layer
batch = args.batch
output_path = args.output

# Processing activation data
values = []
for i in range(layer):
    # Load and sort activation data for each layer
    freq = torch.load(f"{activation_path}/activation_{i}.pt")
    freq, _ = torch.sort(freq, descending=True)
    freq = freq * -1.0
    freq = freq.view(-1, batch)
    freq = freq.sum(dim=1)
    freq = freq.tolist()
    values += freq
# Padding zero values for additional constraints
for i in range(layer):
    values += [0.0]
c = np.array(values, dtype=float)
c = matrix(c)

# Setting capacity and neuron count per batch
CAP = args.capacity
CAP = int(CAP / batch)
neuron = int(neuron / batch)
coeff = []
h = []

# Constraint 1: Total neuron activation constraint
lst = []
for i in range(neuron * layer):
    lst.append(1)
for i in range(layer):
    lst.append(0)
coeff.append(lst)
h.append(CAP)

# Constraint 2: Threshold constraint for GPU split per layer
for i in range(layer):
    lst = [0] * (neuron * layer + layer)
    for j in range(neuron):
        lst[i * neuron + j] = -1
    lst[neuron * layer + i] = int(args.threshold / batch)
    coeff.append(lst)
    h.append(0)

# Constraint 3: Upper bound on neuron activations
for i in range(layer):
    lst = [0] * (neuron * layer + layer)
    for j in range(neuron):
        lst[i * neuron + j] = 1
    lst[neuron * layer + i] = -1000000 # Arbitrary large negative number as an upper bound
    coeff.append(lst)
    h.append(0)

# Convert lists to matrix format for ILP solver
coeff = np.array(coeff, dtype=float)
G = matrix(coeff)
h = np.array(h, dtype=float)
h = matrix(h)

# Define the set of integer and binary variables
I = set(range(neuron * layer + layer))
B = set()

# Solving the ILP problem
(status, x) = ilp(c, G, h, None, None, B, I)
print(f"ILP Status: {status}")
ans = list(x)
print(f"Total Activation Units: {sum(ans)}")

# Serialize the solution
serialize = []
for i in range(layer):
    serialize.append(sum(ans[i * neuron:i * neuron + neuron] * batch))
aligned_lst = serialize

# Save the solution to a pickle file
with open(output_path, 'wb') as handle:
    pickle.dump(aligned_lst, handle)