Branden Butler 2024-03-22 17:42:43 +03:00 committed by GitHub
commit a3dbd17c58
19 changed files with 1755 additions and 312 deletions

View file

@ -431,7 +431,7 @@ if (LLAMA_MPI)
message(STATUS "MPI found")
set(GGML_HEADERS_MPI ggml-mpi.h)
set(GGML_SOURCES_MPI ggml-mpi.c)
set(GGML_SOURCES_MPI ggml-mpi.cpp)
add_compile_definitions(GGML_USE_MPI)
add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS})
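# Editor's note: a typical configure invocation for this path (illustrative;
# LLAMA_MPI is the option tested by the surrounding block):
#   cmake -S . -B build -DLLAMA_MPI=ON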

View file

@ -577,8 +577,8 @@ endif
endif # LLAMA_METAL
ifdef LLAMA_MPI
ggml-mpi.o: ggml-mpi.c ggml-mpi.h
$(CC) $(CFLAGS) -c $< -o $@
ggml-mpi.o: ggml-mpi.cpp ggml-mpi.h
$(CXX) $(CXXFLAGS) -c $< -o $@
endif # LLAMA_MPI
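# Editor's note: with ggml-mpi now built as C++, an MPI build is still expected
# to go through the MPI compiler wrappers; a typical invocation (illustrative,
# following the project's documented MPI build rather than this diff) would be:
#   make CC=mpicc CXX=mpicxx LLAMA_MPI=1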
GF_CC := $(CC)

View file

@ -170,20 +170,39 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
invalid_param = true;
return true;
}
params.n_threads = std::stoi(argv[i]);
if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency();
std::string arg_next = argv[i];
// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads[node] = std::stoi(split_arg[node]);
if (params.n_threads[node] <= 0) {
params.n_threads[node] = std::thread::hardware_concurrency();
}
}
return true;
}
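// Editor's note: a minimal sketch of the list-splitting technique used above;
// std::sregex_token_iterator with submatch index -1 yields the text *between*
// separator matches, so "8,4/2" splits into {"8", "4", "2"}. Hypothetical helper,
// shown in comments only since this spot is inside a function body:
//
//   static std::vector<std::string> split_by_comma_or_slash(const std::string & s) {
//       const std::regex sep{R"([,/]+)"};
//       std::sregex_token_iterator it{s.begin(), s.end(), sep, -1};
//       return {it, {}}; // vector<string> built from the [it, end) token range
//   }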
if (arg == "-tb" || arg == "--threads-batch") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.n_threads_batch = std::stoi(argv[i]);
if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency();
std::string arg_next = argv[i];
// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_batch.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_batch[node] = std::stoi(split_arg[node]);
if (params.n_threads_batch[node] <= 0) {
params.n_threads_batch[node] = std::thread::hardware_concurrency();
}
}
return true;
}
@ -192,9 +211,18 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
invalid_param = true;
return true;
}
params.n_threads_draft = std::stoi(argv[i]);
if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency();
std::string arg_next = argv[i];
// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_draft.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_draft[node] = std::stoi(split_arg[node]);
if (params.n_threads_draft[node] <= 0) {
params.n_threads_draft[node] = std::thread::hardware_concurrency();
}
}
return true;
}
@ -203,9 +231,18 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
invalid_param = true;
return true;
}
params.n_threads_batch_draft = std::stoi(argv[i]);
if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency();
std::string arg_next = argv[i];
// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_batch_draft.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_batch_draft[node] = std::stoi(split_arg[node]);
if (params.n_threads_batch_draft[node] <= 0) {
params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
}
}
return true;
}
@ -891,6 +928,24 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
#endif // GGML_USE_CUBLAS_SYCL
return true;
}
if (arg == "--mpi-layer-split") {
if (++i >= argc) {
invalid_param = true;
return true;
}
std::string arg_next = argv[i];
// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.mpi_layer_split.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.mpi_layer_split[node] = std::stof(split_arg[node]);
}
return true;
}
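// Editor's note: an illustrative invocation of the options parsed above
// (binary, model and values are hypothetical, not taken from this commit):
//   mpirun -np 2 ./main -m model.gguf -t 8,4 --mpi-layer-split 0.75,0.25
// parses to n_threads = {8, 4} and mpi_layer_split = {0.75f, 0.25f}, i.e. node 0
// gets 8 threads and roughly 75% of the layers, node 1 gets 4 threads and the rest.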
if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) {
invalid_param = true;
@ -1275,7 +1330,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" (can be specified more than once for multiple prompts).\n");
printf(" --color colorise output to distinguish prompt and user input from generations\n");
printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads[0]);
printf(" -tb N, --threads-batch N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -td N, --threads-draft N");
@ -1393,6 +1448,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
}
#ifdef GGML_USE_MPI
printf(" --mpi-layer-split N percentiles to split the layers by across nodes\n");
#endif
printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
printf(" -gan N, --grp-attn-n N\n");
@ -1443,9 +1501,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
std::string get_system_info(const gpt_params & params) {
std::ostringstream os;
os << "system_info: n_threads = " << params.n_threads;
if (params.n_threads_batch != -1) {
os << " (n_threads_batch = " << params.n_threads_batch << ")";
os << "system_info: n_threads = " << params.n_threads[0];
if (params.n_threads_batch[0] != -1) {
os << " (n_threads_batch = " << params.n_threads_batch[0] << ")";
}
os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
@ -1590,6 +1648,10 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.kv_overrides = params.kv_overrides.data();
}
free((void *) mparams.node_layer_weights);
mparams.node_layer_weights = params.mpi_layer_split.data();
return mparams;
}
@ -1629,8 +1691,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_seq_max = params.n_parallel;
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.n_threads = params.n_threads[0];
cparams.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0];
cparams.seed = params.seed;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
@ -1916,6 +1978,7 @@ struct llama_model * llama_load_model_from_hf(
#endif // LLAMA_USE_CURL
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
int32_t n_threads = params.n_threads[0];
auto mparams = llama_model_params_from_gpt_params(params);
llama_model * model = nullptr;
@ -1942,6 +2005,16 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
return std::make_tuple(nullptr, nullptr);
}
#ifdef GGML_USE_MPI
int node_id = llama_node_id(lctx);
n_threads = (node_id >= params.n_threads.size()) ? get_num_physical_cores() : params.n_threads[node_id];
int32_t n_threads_batch = (node_id >= params.n_threads_batch.size()) ? -1 : params.n_threads_batch[node_id];
params.n_threads[0] = n_threads; // So we can treat index 0 as what our n_threads is elsewhere
params.n_threads_batch[0] = n_threads_batch;
llama_set_n_threads(lctx, n_threads, (n_threads_batch > 0) ? n_threads_batch : get_num_physical_cores());
#endif
if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
@ -1975,7 +2048,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
((i > 0) || params.lora_base.empty())
? NULL
: params.lora_base.c_str(),
params.n_threads);
n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
llama_free(lctx);
@ -1991,10 +2064,17 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
{
LOG("warming up the model with an empty run\n");
#ifndef GGML_USE_MPI
// When using MPI, llama_decode() enters an infinite loop
// on non-head nodes, so we only warm up the model here
// when not using MPI.
// FIXME: add a way to terminate the infinite loop so the model
// can also be warmed up in MPI mode
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
#endif
llama_reset_timings(lctx);
}
@ -2385,7 +2465,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector);
fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "threads: %d # default: %u\n", params.n_threads[0], std::thread::hardware_concurrency());
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);

View file

@ -47,11 +47,10 @@ int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_threads_batch_draft = -1;
std::vector<int32_t> n_threads = {get_num_physical_cores()};
std::vector<int32_t> n_threads_batch = {-1}; // number of threads to use for batch processing (-1 = use n_threads)
std::vector<int32_t> n_threads_draft = {get_num_physical_cores()};
std::vector<int32_t> n_threads_batch_draft = {-1}; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
@ -65,6 +64,7 @@ struct gpt_params {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
std::vector<float> mpi_layer_split = {1.0}; // list of percentages of the total number of layers
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.

View file

@ -102,8 +102,8 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = 512;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.n_threads[0];
ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0];
// ensure enough sequences are available
ctx_params.n_seq_max = *std::max_element(n_pl.begin(), n_pl.end());

View file

@ -83,8 +83,8 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_seq_max = n_parallel;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.n_threads[0];
ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0];
llama_context * ctx = llama_new_context_with_model(model, ctx_params);

View file

@ -125,14 +125,14 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
if (!params->image.empty()) {
fprintf(stderr, "using base64 encoded image instead of command line image path\n");
}
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads[0], prompt);
if (!embed) {
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
return NULL;
}
params->prompt = remove_image_from_prompt(prompt);
} else {
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads[0], params->image.c_str());
if (!embed) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
return NULL;

View file

@ -207,6 +207,8 @@ int main(int argc, char ** argv) {
return 1;
}
llama_split_layers_weighted(ctx, params.mpi_layer_split.data(), params.mpi_layer_split.size());
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);

View file

@ -94,8 +94,8 @@ int main(int argc, char ** argv) {
ctx_params.seed = seed;
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
ctx_params.n_batch = 512;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.n_threads[0];
ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0];
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");

View file

@ -2421,7 +2421,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true;
break;
}
params.n_threads = std::stoi(argv[i]);
params.n_threads[0] = std::stoi(argv[i]);
} else if (arg == "--grp-attn-n" || arg == "-gan") {
if (++i >= argc) {
invalid_param = true;
@ -2441,7 +2441,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true;
break;
}
params.n_threads_batch = std::stoi(argv[i]);
params.n_threads_batch[0] = std::stoi(argv[i]);
} else if (arg == "--threads-http") {
if (++i >= argc) {
invalid_param = true;

View file

@ -53,8 +53,8 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = 2048;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.n_threads[0];
ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0];
llama_context * ctx = llama_new_context_with_model(model, ctx_params);

View file

@ -71,7 +71,7 @@ int main(int argc, char ** argv) {
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.n_threads_draft > 0) {
if (params.n_threads_draft.size() > 0) {
params.n_threads = params.n_threads_draft;
}
params.n_threads_batch = params.n_threads_batch_draft;

View file

@ -1719,7 +1719,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
bool parallel) {
GGML_ASSERT(n_backends > 0);
GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
GGML_ASSERT(ggml_backend_buft_is_host(ggml_backend_get_default_buffer_type(backends[n_backends - 1]))); // last backend must be host
struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1);

View file

@ -1,216 +0,0 @@
#include "ggml-mpi.h"
#include "ggml.h"
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define UNUSED GGML_UNUSED
struct ggml_mpi_context {
int rank;
int size;
};
void ggml_mpi_backend_init(void) {
MPI_Init(NULL, NULL);
}
void ggml_mpi_backend_free(void) {
MPI_Finalize();
}
struct ggml_mpi_context * ggml_mpi_init(void) {
struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
return ctx;
}
void ggml_mpi_free(struct ggml_mpi_context * ctx) {
free(ctx);
}
int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
return ctx->rank;
}
void ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
int * n_tokens,
int * n_past,
int * n_threads) {
UNUSED(ctx_mpi);
// synchronize the worker node parameters with the root node
MPI_Barrier(MPI_COMM_WORLD);
MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
}
static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
struct ggml_tensor * t = ggml_graph_get_tensor(gf, name);
if (t == NULL) {
fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
return -1;
}
for (int i = 0; i < gf->n_nodes; i++) {
if (gf->nodes[i] == t) {
return i;
}
}
fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name);
return -1;
}
static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) {
MPI_Datatype mpi_type;
switch (t->type) {
case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break;
case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break;
default: GGML_ASSERT(false && "not implemented");
}
const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, MPI_COMM_WORLD);
GGML_ASSERT(retval == MPI_SUCCESS);
}
static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
MPI_Datatype mpi_type;
switch (t->type) {
case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break;
case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break;
default: GGML_ASSERT(false && "not implemented");
}
MPI_Status status; UNUSED(status);
const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
GGML_ASSERT(retval == MPI_SUCCESS);
}
// TODO: there are many improvements that can be done to this implementation
void ggml_mpi_graph_compute_pre(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf,
int n_layers) {
const int mpi_rank = ctx_mpi->rank;
const int mpi_size = ctx_mpi->size;
struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens");
if (inp_tokens == NULL) {
fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
return;
}
struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0");
if (inp0 == NULL) {
fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
return;
}
GGML_ASSERT(inp0 == gf->nodes[0]);
// distribute the compute graph into slices across the MPI nodes
//
// the main node (0) processes the last layers + the remainder of the compute graph
// and is responsible to pass the input tokens to the first node (1)
//
// node 1: [( 0) * n_per_node, ( 1) * n_per_node)
// node 2: [( 1) * n_per_node, ( 2) * n_per_node)
// ...
// node n-1: [(n-2) * n_per_node, (n-1) * n_per_node)
// node 0: [(n-1) * n_per_node, n_nodes)
//
if (mpi_rank > 0) {
if (mpi_rank == 1) {
// the first node (1) receives the input tokens from the main node (0)
ggml_mpi_tensor_recv(inp_tokens, 0);
} else {
// recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
ggml_mpi_tensor_recv(inp0, mpi_rank - 1);
}
} else if (mpi_size > 1) {
// node 0 sends the input tokens to node 1
ggml_mpi_tensor_send(inp_tokens, 1);
// recv the output data from the last node
ggml_mpi_tensor_recv(inp0, mpi_size - 1);
}
{
const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;
const int il0 = (mpi_idx + 0) * n_per_node;
const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node);
char name_l0[GGML_MAX_NAME];
char name_l1[GGML_MAX_NAME];
snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0);
snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1);
const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0);
const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes;
if (idx_l0 < 0 || idx_l1 < 0) {
fprintf(stderr, "%s: layer input nodes not found\n", __func__);
return;
}
// attach the input data to all nodes that need it
// TODO: not great - should be able to do this without modifying the compute graph (see next TODO below)
for (int i = idx_l0; i < idx_l1; i++) {
if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) {
gf->nodes[i]->src[0] = inp0;
}
if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) {
gf->nodes[i]->src[1] = inp0;
}
}
// TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
for (int i = 1; i < idx_l1 - idx_l0; i++) {
gf->nodes[i] = gf->nodes[idx_l0 + i];
gf->grads[i] = gf->grads[idx_l0 + i];
}
// the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node
if (mpi_idx != 0) {
gf->nodes[0]->op = GGML_OP_NONE;
}
gf->n_nodes = idx_l1 - idx_l0;
//fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
}
}
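// Editor's note: a small sketch illustrating the slice arithmetic of the (removed)
// implementation above. Worked example: n_layers = 32, mpi_size = 4 gives
// n_per_node = 8, so rank 1 -> [0, 8), rank 2 -> [8, 16), rank 3 -> [16, 24),
// rank 0 -> [24, 32) plus the remainder of the graph. Hypothetical helper,
// mirroring the code above:
static void ggml_mpi_print_layer_slices(int n_layers, int mpi_size) {
    const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
    for (int mpi_rank = 0; mpi_rank < mpi_size; mpi_rank++) {
        const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;
        const int il0 = (mpi_idx + 0) * n_per_node;
        const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node);
        fprintf(stderr, "rank %d: layers [%d, %d)\n", mpi_rank, il0, il1);
    }
}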
void ggml_mpi_graph_compute_post(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf,
int n_layers) {
UNUSED(n_layers);
const int mpi_rank = ctx_mpi->rank;
const int mpi_size = ctx_mpi->size;
// send the output data to the next node
if (mpi_rank > 0) {
ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size);
}
}

ggml-mpi.cpp (new file, 1044 lines)

File diff suppressed because it is too large.

View file

@ -1,4 +1,8 @@
#pragma once
#include <stdint.h>
#include <stddef.h>
#include "ggml.h"
#include "ggml-backend.h"
struct ggml_context;
struct ggml_tensor;
@ -8,31 +12,238 @@ struct ggml_cgraph;
extern "C" {
#endif
#define GGML_MPI_DECODE 0
#define GGML_MPI_KV_CLEAR 1
#define GGML_MPI_KV_SEQ_RM 2
#define GGML_MPI_KV_SEQ_CP 3
#define GGML_MPI_KV_SEQ_KEEP 4
#define GGML_MPI_KV_SEQ_ADD 5
#define GGML_MPI_SHUTDOWN 6
#define GGML_MPI_TRANSFER_TENSORS 7
#define GGML_MPI_SYNC_LOGITS 8
#define GGML_MPI_CANCEL_RUN 9
#define GGML_MPI_KV_SEQ_CP_BACK 10
#define GGML_MPI_TRANS_ID 11
#define GGML_MPI_BATCH_ID 12
#define GGML_MPI_N_TOKENS 13
#define GGML_MPI_TOKENS 14
#define GGML_MPI_N_SEQ_IDS 15
#define GGML_MPI_SEQ_IDS 16
#define GGML_MPI_POS 17
#define GGML_MPI_BEGIN_TRANSACTION 18
#define GGML_MPI_MAX_N_SEQ 19
#define GGML_MPI_BATCH_LOGITS 20
#define GGML_MPI_KV_SEQ_DIV 21
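// Editor's note: these tags form a small head-node -> worker command protocol.
// A sketch of the pattern used by the KV-cache wrappers later in this commit
// (mirrors llama_kv_cache_clear(): rank 0 announces the transaction, then all
// ranks synchronize the operation's arguments under the operation's own tag):
//
//   if (ggml_mpi_rank(ctx_mpi) == 0 && ggml_mpi_size(ctx_mpi) > 1) {
//       int transaction_type = GGML_MPI_KV_CLEAR;
//       ggml_mpi_sync_ints_pipelined(ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
//   }
//   ggml_mpi_sync_ints_pipelined(ctx_mpi, NULL, 0, GGML_MPI_KV_CLEAR);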
/**
* The context used for MPI operations,
* a program may make use of more than one
* context but must always have at least one.
*
* The context stores required information like the
* node rank and a communicator to use for MPI operations.
* A context is guaranteed to be internally consistent,
* meaning that a context's stored rank is valid within
* the context's communicator.
*/
struct ggml_mpi_context;
/**
* Initialize the MPI library and the GGML MPI backend.
* Calling more than once during the lifetime of the program
* leads to undefined behavior. This function must be called before
* any MPI operations.
*/
void ggml_mpi_backend_init(void);
/**
* Frees the MPI backend, must be called only once at termination
* of the program. No MPI operations may be completed after calling this function,
* and attempting to do so will lead to undefined behavior.
*/
void ggml_mpi_backend_free(void);
/**
* Construct a new MPI context using the MPI_COMM_WORLD
* communicator. This is useful only for creating the
* initial context; calling it multiple times
* effectively creates copies of the same data.
*
* @return A context for use in the global communicator.
*/
struct ggml_mpi_context * ggml_mpi_init(void);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(ggml_backend_buffer_type_t buft);
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf);
void ggml_mpi_sync_ints_pipelined(
struct ggml_mpi_context * ctx_mpi,
int32_t * vals,
int count,
int tag
);
void ggml_mpi_sync_ints_pipelined_back(
struct ggml_mpi_context * ctx_mpi,
int32_t * vals,
int count,
int tag
);
// clear = 1, rm = 2, cp = 3, keep = 4, seq_shift = 5
void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi);
int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi);
/**
* Create a new context by splitting the given context's
* communicator, creating a "sub-communicator." This is a collective
* operation and must be performed by all nodes within the same communicator.
* The color and key have the same meaning as in MPI_Comm_split(), i.e.
* the color is used to determine the sub-communicator this node will belong to,
* and the key is the relative rank of this node in the new communicator.
*
* An example: if a node passes a color of 1, and a different node passes a color of 2,
* the nodes will belong to two different sub-communicators. If two nodes pass the same
* color, then their ranks will be ordered by the order of their keys. If they pass the same
* key, then the tie will be broken by the nodes' ranks in the old communicator.
*
* The communicator used by the given context remains entirely valid, so it is advisable
* to store both the old and new contexts. This allows an application to
* select at runtime which communicator to perform MPI operations with. An example
* would be to segregate the nodes into multiple domains categorized by the functions
* they perform, and use the original context to broadcast to all nodes in the cluster.
*
* @param ctx The context containing the communicator to split.
* @param color The sub-communicator that this node will belong to.
* @param key The relative rank of this node in the new communicator.
* @return A new context with all values referencing the newly-created communicator.
*/
struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int color, int key);
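// Editor's note: a hedged usage sketch; the color/key semantics follow the
// MPI_Comm_split() description above, and the even/odd grouping is illustrative:
//
//   struct ggml_mpi_context * world = ggml_mpi_init();
//   int rank = ggml_mpi_rank(world);
//   // even and odd ranks end up in two separate sub-communicators,
//   // ordered by their original rank
//   struct ggml_mpi_context * sub = ggml_mpi_split_comm(world, rank % 2, rank);
//   // ... use `sub` for domain-local operations, `world` for cluster-wide ones ...
//   ggml_mpi_free(sub);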
/**
* Frees the given context, including the communicator. No MPI
* operations besides ggml_mpi_backend_free(void) should be executed after
* running this function.
*
* @param ctx The context to free.
*/
void ggml_mpi_free(struct ggml_mpi_context * ctx);
/**
* Get the rank of this node in the given context's communicator.
*
* @param ctx The context whose communicator is used to determine the rank.
* @return The rank of this node.
*/
int ggml_mpi_rank(struct ggml_mpi_context * ctx);
/**
* Get the number of nodes that are a part of
* the communicator referenced by the given context.
*
* @param ctx The context containing the communicator used for this size check.
* @return The number of nodes that are a part of the given context's communicator.
*/
size_t ggml_mpi_size(struct ggml_mpi_context * ctx);
/**
* Synchronize needed information among the nodes
* to prepare for running an evaluation iteration.
* This is a collective operation and all nodes must
* call this function. It will block until all
* nodes have entered it, to prevent any desync
* between nodes.
*
* @param ctx_mpi The context in which to prepare for evaluation.
* @param n_tokens A pointer to the n_tokens, which will be synchronized after this function.
* @param pos A pointer to the pos array, which will be synchronized after this function.
* @param n_seq_ids A pointer to the n_seq_ids array, which will be synchronized after this function.
* @param seq_id A pointer to the seq_id 2D array, which will be synchronized after this function.
* @param logits A pointer to the logits array, which is unused currently since only node 0 needs them.
*/
void ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
int * n_tokens,
int * n_past,
int * n_threads);
struct ggml_mpi_context * ctx_mpi,
int32_t * n_tokens,
int32_t ** pos,
int32_t ** n_seq_ids,
int32_t *** seq_id,
int8_t ** logits,
uint32_t n_seq_max);
void ggml_mpi_graph_compute_pre(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf,
int n_layers);
void ggml_mpi_sync_int(
struct ggml_mpi_context * ctx_mpi,
int32_t * val
);
void ggml_mpi_graph_compute_post(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf,
int n_layers);
/**
* Split a range across all nodes within the given
* context, weighting the allocations by the given weights.
* The dimensions of the returned 2d array are (number of nodes in the context, 2).
* The first element in the inner array is the starting point of the range allocated
* to the node indicated by the index into the outer array,
* and the second element is the end point of the allocated range, inclusive.
*
* @param ctx_mpi The context used to determine the number of nodes
* to split the range across.
* @param start The starting point of the range.
* @param end The end point of the range, inclusive.
* @param node_weights How to weight the allocations across the nodes,
* must sum to 1.0.
* @return A 2d array, the first dimension is the number of nodes in the context
* and the second dimension is 2.
*/
uint16_t** ggml_mpi_split_range(
struct ggml_mpi_context * ctx_mpi,
uint16_t start,
uint16_t end,
const float node_weights[]
);
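// Editor's note: a hedged usage sketch, assuming ctx_mpi is an initialized
// context on a two-node run and that the weights sum to 1.0 (values are
// illustrative); with 32 layers this would give node 0 roughly the first
// three quarters of the layers and node 1 the remainder:
//
//   const float weights[2] = {0.75f, 0.25f};
//   uint16_t ** ranges = ggml_mpi_split_range(ctx_mpi, 0, 31, weights);
//   for (size_t i = 0; i < ggml_mpi_size(ctx_mpi); i++) {
//       printf("node %zu: layers %d..%d\n", i, ranges[i][0], ranges[i][1]);
//   }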
// BACKEND V2
struct ggml_mpi_device {
int index;
struct ggml_mpi_context * ctx_mpi;
const char * name;
int subgroupSize;
};
#define MPI_BACKEND_NAME "MPI"
GGML_CALL int ggml_backend_mpi_reg_devices();
GGML_CALL ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends, int rank);
GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t buft, int rank);
GGML_CALL void ggml_backend_mpi_buffer_set_rank(ggml_backend_buffer_t buft, int rank);
#ifdef __cplusplus
}

ggml.h (3 changes)
View file

@ -226,7 +226,7 @@
#define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_CONTEXTS 256
#define GGML_MAX_SRC 10
#ifndef GGML_MAX_NAME
#define GGML_MAX_NAME 64
@ -381,6 +381,7 @@ extern "C" {
GGML_BACKEND_TYPE_CPU = 0,
GGML_BACKEND_TYPE_GPU = 10,
GGML_BACKEND_TYPE_GPU_SPLIT = 20,
GGML_BACKEND_TYPE_MPI_SPLIT = 30,
};
// model file types

llama.cpp (394 changes)
View file

@ -1118,6 +1118,10 @@ struct llama_mmap {
int flags = MAP_SHARED;
// prefetch/readahead impairs performance on NUMA systems
if (numa) { prefetch = 0; }
#ifdef GGML_USE_MPI
prefetch = 0;
#endif
#ifdef __linux__
// advise the kernel to read the file sequentially (increases readahead)
if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
@ -1126,6 +1130,7 @@ struct llama_mmap {
}
if (prefetch) { flags |= MAP_POPULATE; }
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) { // NOLINT
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
@ -1487,6 +1492,10 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer
if (buft == nullptr) {
buft = ggml_backend_cpu_buffer_type();
}
#if defined(GGML_USE_MPI)
buft = ggml_backend_mpi_wrap_buffer_type(buft);
#endif
return buft;
GGML_UNUSED(host_buffer);
@ -1538,6 +1547,8 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_g
if (buft == nullptr) {
buft = llama_default_buffer_type_offload(fallback_gpu);
}
return buft;
GGML_UNUSED(tensor_split);
@ -1761,6 +1772,7 @@ struct llama_cparams {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
uint32_t n_seq_max;
};
struct llama_layer {
@ -2036,6 +2048,10 @@ struct llama_model {
int64_t t_load_us = 0;
int64_t t_start_us = 0;
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = nullptr;
#endif
~llama_model() {
for (struct ggml_context * ctx : ctxs) {
ggml_free(ctx);
@ -2139,12 +2155,9 @@ struct llama_context {
struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
// control vectors
struct llama_control_vector cvec;
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
#endif
};
//
@ -2218,7 +2231,7 @@ static bool llama_kv_cache_init(
};
ggml_context * ctx = ggml_init(params);
if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to allocate context for kv cache\n", __func__);
LLAMA_LOG_ERROR("%s: failed to allocate context for kv cache, n_layers=%d\n", __func__, n_layers);
return false;
}
ctx_map[it.first] = ctx;
@ -3321,6 +3334,11 @@ static void llm_load_hparams(
auto & hparams = model.hparams;
const gguf_context * ctx = ml.ctx_gguf;
#ifdef GGML_USE_MPI
model.ctx_mpi = ggml_mpi_init();
#endif
// get metadata as string
for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
enum gguf_type type = gguf_get_kv_type(ctx, i);
@ -4062,6 +4080,7 @@ static bool llm_load_tensors(
enum llama_split_mode split_mode,
int main_gpu,
const float * tensor_split,
const float * node_split,
bool use_mlock,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
@ -4150,6 +4169,32 @@ static bool llm_load_tensors(
}
}
#ifdef GGML_USE_MPI
uint16_t** ranges = ggml_mpi_split_range(model.ctx_mpi, 0, n_layer - 1, node_split);
size_t size = ggml_mpi_size(model.ctx_mpi);
for (size_t i = 0; i < size; i++) {
for (uint16_t j = ranges[i][0]; j < ranges[i][1]; j++) {
printf("Setting buffer rank for i %zu and j %d\n", i, j);
ggml_backend_mpi_buffer_type_set_rank(model.buft_layer[j].buft, (int)i);
ggml_backend_mpi_buffer_type_set_rank(model.buft_layer[j].buft_matrix, (int)i);
}
}
// Will run with inputs on other nodes, but output may not be correct.
// Default is node 0 anyway, but better to be explicit about it
ggml_backend_mpi_buffer_type_set_rank(model.buft_input.buft, 0);
ggml_backend_mpi_buffer_type_set_rank(model.buft_input.buft_matrix, 0);
// Outputs *must* be on node 0, otherwise a deadlock occurs
ggml_backend_mpi_buffer_type_set_rank(model.buft_output.buft, 0);
ggml_backend_mpi_buffer_type_set_rank(model.buft_output.buft_matrix, 0);
#endif
// count used buffer types
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
buft_layer_count[model.buft_input.buft]++;
@ -5041,6 +5086,7 @@ static bool llm_load_tensors(
size_t first, last;
ml.get_mapping_range(&first, &last, ctx);
buf = ggml_backend_cpu_buffer_from_ptr((char *) ml.mapping->addr + first, last - first);
#ifdef GGML_USE_CUBLAS
if (n_layer >= n_gpu_layers) {
ggml_backend_cuda_register_host_buffer(
@ -5066,9 +5112,14 @@ static bool llm_load_tensors(
mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
}
}
if (buf == nullptr) {
throw std::runtime_error("failed to allocate buffer");
}
#ifdef GGML_USE_MPI
buf = ggml_backend_mpi_wrap_buffer(buf);
#endif
// indicate that this buffer contains weights
// this is used by ggml_backend_sched to improve op scheduling -> ops that use a weight are preferably scheduled to the backend that contains the weight
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
@ -5181,7 +5232,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
#endif
if (!llm_load_tensors(
ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.node_layer_weights, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
)) {
return -2;
@ -5840,6 +5891,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
struct ggml_tensor * inpSA = inpL;
// norm
@ -6024,6 +6076,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
struct ggml_tensor * inpSA = inpL;
cur = llm_build_norm(ctx0, inpL, hparams,
@ -6132,6 +6185,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
struct ggml_tensor * attn_norm;
attn_norm = llm_build_norm(ctx0, inpL, hparams,
@ -6250,6 +6304,7 @@ struct llm_build_context {
cb(inpL, "inpL", -1);
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm,
model.layers[il].attn_norm_b,
@ -6337,6 +6392,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
struct ggml_tensor * residual = inpL;
cur = llm_build_norm(ctx0, inpL, hparams,
@ -6536,6 +6592,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
struct ggml_tensor * inpSA = inpL;
cur = llm_build_norm(ctx0, inpL, hparams,
@ -6815,6 +6872,7 @@ struct llm_build_context {
cb(inpL, "inp_norm", -1);
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm,
model.layers[il].attn_norm_b,
@ -6902,6 +6960,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); //MPI
struct ggml_tensor * attn_norm;
attn_norm = llm_build_norm(ctx0, inpL, hparams,
@ -8975,10 +9034,7 @@ static void llama_graph_compute(
llama_context & lctx,
ggml_cgraph * gf,
int n_threads) {
#ifdef GGML_USE_MPI
const int64_t n_layer = lctx.model.hparams.n_layer;
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
#endif
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(lctx.backend_metal)) {
@ -8995,9 +9051,6 @@ static void llama_graph_compute(
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
#ifdef GGML_USE_MPI
ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
#endif
}
// decode a batch of tokens by evaluating the transformer
@ -9011,9 +9064,62 @@ static void llama_graph_compute(
//
static int llama_decode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch
llama_batch & batch_all) { // TODO: rename back to batch
const uint32_t n_tokens_all = batch_all.n_tokens;
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(lctx.model.ctx_mpi) == 0 && ggml_mpi_size(lctx.model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_DECODE;
ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
// ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.batch_id, 1, GGML_MPI_BATCH_ID);
int old_tokens = batch_all.n_tokens;
ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.n_tokens, 1, GGML_MPI_N_TOKENS);
ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, reinterpret_cast<int32_t *>(&lctx.cparams.n_seq_max), 1, GGML_MPI_MAX_N_SEQ);
if (ggml_mpi_rank(lctx.model.ctx_mpi) > 0) {
int new_n_tokens = batch_all.n_tokens;
llama_batch_free(batch_all);
batch_all = llama_batch_init(new_n_tokens, 0, (int32_t)lctx.cparams.n_seq_max);
}
#endif
uint32_t n_tokens_all = batch_all.n_tokens;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
if (batch_all.pos == nullptr) {
pos.resize(n_tokens_all);
for (uint32_t i = 0; i < n_tokens_all; i++) {
pos[i] = batch_all.all_pos_0 + i*batch_all.all_pos_1;
}
batch_all.pos = pos.data();
}
if (batch_all.seq_id == nullptr) {
n_seq_id.resize(n_tokens_all);
seq_id.resize(n_tokens_all);
seq_id_arr.resize(n_tokens_all);
for (uint32_t i = 0; i < n_tokens_all; i++) {
n_seq_id[i] = 1;
seq_id[i].resize(lctx.cparams.n_seq_max);
seq_id[i][0] = batch_all.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
batch_all.n_seq_id = n_seq_id.data();
batch_all.seq_id = seq_id_arr.data();
}
#ifdef GGML_USE_MPI
ggml_mpi_eval_init(lctx.model.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
n_tokens_all = batch_all.n_tokens;
#endif
if (n_tokens_all == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
@ -9035,11 +9141,7 @@ static int llama_decode_internal(
}
lctx.n_queued_tokens += n_tokens_all;
#ifdef GGML_USE_MPI
// TODO: needs fix after #3228
GGML_ASSERT(false && "not implemented");
//ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
#endif
auto & kv_self = lctx.kv_self;
@ -9059,13 +9161,9 @@ static int llama_decode_internal(
const auto n_ubatch = cparams.n_ubatch;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
llama_batch u_batch = {
/* .n_tokens = */ (int32_t) n_tokens,
/* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
@ -9099,7 +9197,7 @@ static int llama_decode_internal(
seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
n_seq_id[i] = 1;
seq_id[i].resize(1);
seq_id[i].resize(lctx.cparams.n_seq_max);
seq_id[i][0] = u_batch.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
@ -9200,11 +9298,17 @@ static int llama_decode_internal(
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(lctx.model.ctx_mpi) == 0) {
#endif
// extract logits
// TODO: do not compute and extract logits if only embeddings are needed
// update the graphs to skip "result_output" if logits are not needed
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
if (u_batch.logits) {
int32_t i_first = -1;
@ -9288,6 +9392,10 @@ static int llama_decode_internal(
} break;
}
}
#ifdef GGML_USE_MPI
}
#endif
}
// wait for the computation to finish (automatically done when obtaining the model output)
@ -9305,6 +9413,8 @@ static int llama_decode_internal(
}
}
return 0;
}
@ -12824,6 +12934,7 @@ static int llama_apply_lora_from_file_internal(
//
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
static_cast<float *>(calloc(1, sizeof(float))),
/*.n_gpu_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
@ -12891,6 +13002,14 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
return result;
}
int llama_node_id(struct llama_context * ctx) {
#ifdef GGML_USE_MPI
return ggml_mpi_rank(ctx->model.ctx_mpi);
#endif
return 0;
}
size_t llama_max_devices(void) {
#if defined(GGML_USE_METAL)
return 1;
@ -12936,6 +13055,7 @@ void llama_backend_init(void) {
#ifdef GGML_USE_MPI
ggml_mpi_backend_init();
#endif
}
void llama_numa_init(enum ggml_numa_strategy numa) {
@ -13022,7 +13142,7 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
// TODO: maybe add n_seq_max here too
cparams.n_seq_max = params.n_seq_max;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -13192,6 +13312,8 @@ struct llama_context * llama_new_context_with_model(
ctx->backends.push_back(backend);
}
#endif
ctx->backend_cpu = ggml_backend_cpu_init();
if (ctx->backend_cpu == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
@ -13200,6 +13322,23 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(ctx->backend_cpu);
#ifdef GGML_USE_MPI
std::vector<ggml_backend_t> new_backends;
for (size_t i = 0; i < ggml_mpi_size(model->ctx_mpi); i++) {
new_backends.push_back(ggml_backend_mpi_init(ctx->backends.data(), ctx->backends.size(), (int) i));
}
ctx->backends = new_backends;
// ctx->backend_cpu = ctx->backends.back();
ctx->backends.push_back(ggml_backend_mpi_init(&ctx->backend_cpu, 1, ggml_mpi_rank(model->ctx_mpi)));
#endif
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
@ -13311,23 +13450,15 @@ struct llama_context * llama_new_context_with_model(
}
}
#ifdef GGML_USE_MPI
ctx->ctx_mpi = ggml_mpi_init();
if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process
// TODO: needs fix after #3228
GGML_ASSERT(false && "not implemented");
//const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx));
//while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {};
llama_backend_free();
exit(1);
}
#endif
return ctx;
}
void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights) {
}
void llama_free(struct llama_context * ctx) {
delete ctx;
}
@ -13711,14 +13842,56 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
}
void llama_kv_cache_clear(struct llama_context * ctx) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_CLEAR;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, nullptr, 0, GGML_MPI_KV_CLEAR);
#endif
llama_kv_cache_clear(ctx->kv_self);
}
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_RM;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[3] = {seq_id, p0, p1};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 3, GGML_MPI_KV_SEQ_RM);
seq_id = vals[0];
p0 = vals[1];
p1 = vals[2];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
// }
#endif
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_CP;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP);
// if(ggml_mpi_recv_trans_id(ctx->model.ctx_mpi) < ggml_mpi_trans_id(ctx->model.ctx_mpi)) {
//// return;
// }
// ggml_mpi_inc_trans_id(ctx->model.ctx_mpi);
seq_id_src = vals[0];
seq_id_dst = vals[1];
p0 = vals[2];
p1 = vals[3];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nCopying sequence %d to sequence %d from %d to %d\n", seq_id_src, seq_id_dst, p0, p1);
// }
#endif
if (seq_id_src == seq_id_dst) {
return;
}
@ -13726,10 +13899,35 @@ void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src,
}
void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_KEEP;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[1] = {seq_id};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 1, GGML_MPI_KV_SEQ_KEEP);
seq_id = vals[0];
#endif
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
}
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_ADD;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[4] = {seq_id, p0, p1, delta};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_ADD);
seq_id = vals[0];
p0 = vals[1];
p1 = vals[2];
delta = vals[3];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
// }
#endif
if (delta == 0) {
return;
}
@ -13738,6 +13936,22 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla
}
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_DIV;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[4] = {seq_id, p0, p1, d};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_DIV);
seq_id = vals[0];
p0 = vals[1];
p1 = vals[2];
d = vals[3];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
// }
#endif
if (d == 1) {
return;
}
@ -14244,11 +14458,111 @@ void llama_batch_free(struct llama_batch batch) {
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
batch.token = nullptr;
batch.embd = nullptr;
batch.pos = nullptr;
batch.n_seq_id = nullptr;
batch.seq_id = nullptr;
batch.logits = nullptr;
}
#ifdef GGML_USE_MPI
int llama_process_mpi_transaction(
struct llama_context * ctx,
struct llama_batch & batch,
int tag) {
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
// printf("\nBeginning transaction type %d\n", tag);
// }
switch (tag) {
case GGML_MPI_DECODE:
// llama_batch_free(batch);
return llama_decode_internal(*ctx, batch);
break;
case GGML_MPI_KV_CLEAR:
llama_kv_cache_clear(ctx);
break;
case GGML_MPI_KV_SEQ_RM:
llama_kv_cache_seq_rm(ctx, 1, -1, -1);
break;
case GGML_MPI_KV_SEQ_CP:
llama_kv_cache_seq_cp(ctx, 0, 0, 0, 0);
break;
// case GGML_MPI_KV_SEQ_CP_BACK:
// llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0);
// break;
case GGML_MPI_KV_SEQ_KEEP:
llama_kv_cache_seq_keep(ctx, 0);
break;
case GGML_MPI_KV_SEQ_ADD:
llama_kv_cache_seq_add(ctx, 0, 0, 0, 0);
break;
case GGML_MPI_KV_SEQ_DIV:
llama_kv_cache_seq_div(ctx, 0, 0, 0, 0);
break;
default:
printf("Unknown operation, exiting\n");
exit(1);
break;
}
return 0;
}
int llama_process_mpi_worker(
struct llama_context * ctx,
struct llama_batch & batch) {
ggml_mpi_probe(ctx->model.ctx_mpi, -1, -1);
int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi);
int32_t count;
int32_t trans_type;
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
// printf("\nReceived command %d\n", tag);
// }
switch (tag) {
case GGML_MPI_BEGIN_TRANSACTION:
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &trans_type, 1, GGML_MPI_BEGIN_TRANSACTION);
return llama_process_mpi_transaction(ctx, batch, trans_type);
break;
case GGML_MPI_SHUTDOWN:
llama_free(ctx);
llama_backend_free();
exit(0);
break;
case GGML_MPI_CANCEL_RUN:
// count = ggml_mpi_status_count_int32(ctx->model.ctx_mpi);
//// printf("Received cancel run\n");
// {
// std::vector<int32_t> canceled(count, -1);
// llama_cancel_run(ctx, canceled.data(), canceled.size());
//
// }
// break;
default:
printf("Unknown operation, exiting\n");
exit(1);
break;
}
return 0;
}
#endif
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) > 0) {
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process
while (llama_process_mpi_worker(ctx, batch) >= 0){};
llama_backend_free();
exit(1);
}
#endif
const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);

View file

@ -8,7 +8,6 @@
#include <stdint.h>
#include <stdio.h>
#include <stdbool.h>
#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef LLAMA_BUILD
@ -203,6 +202,9 @@ extern "C" {
};
struct llama_model_params {
// Per-node weights used to split the model's layers across MPI nodes
const float * node_layer_weights;
int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
@ -358,6 +360,8 @@ extern "C" {
const char * path_model,
struct llama_model_params params);
LLAMA_API void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights);
LLAMA_API void llama_free_model(struct llama_model * model);
LLAMA_API struct llama_context * llama_new_context_with_model(
@ -371,6 +375,9 @@ extern "C" {
LLAMA_API size_t llama_max_devices(void);
// Get the ID of this compute node, usually 0
// unless running MPI, in which case it is the rank of the node
LLAMA_API int llama_node_id(struct llama_context * ctx);
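// Editor's note: a hedged usage sketch; under MPI only the head node (rank 0)
// extracts logits, so callers typically gate output on the node ID (the gating
// shown is an assumed pattern, not code from this commit):
//
//   if (llama_node_id(ctx) == 0) {
//       // sample / print tokens only on the head node
//   }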
LLAMA_API bool llama_supports_mmap (void);
LLAMA_API bool llama_supports_mlock (void);
LLAMA_API bool llama_supports_gpu_offload(void);