pipeline parallelism demo
This commit is contained in:
parent
f172de03f1
commit
dbbaf82758
4 changed files with 261 additions and 189 deletions
|
@ -1149,7 +1149,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// warmup run
|
||||
if (t.n_prompt > 0) {
|
||||
test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
|
||||
//test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, std::min(t.n_prompt, 32), 0, t.n_batch, t.n_threads);
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
test_gen(ctx, 1, 0, t.n_threads);
|
||||
|
|
|
@ -319,6 +319,13 @@ struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
|
|||
return alloc->buffer;
|
||||
}
|
||||
|
||||
void ggml_tallocr_set_buffer(ggml_tallocr_t talloc, struct ggml_backend_buffer * buffer) {
|
||||
talloc->buffer = buffer;
|
||||
talloc->base = ggml_backend_buffer_get_base(buffer);
|
||||
talloc->alignment = ggml_backend_buffer_get_alignment(buffer);
|
||||
ggml_tallocr_reset(talloc);
|
||||
}
|
||||
|
||||
void ggml_tallocr_free(ggml_tallocr_t alloc) {
|
||||
if (alloc == NULL) {
|
||||
return;
|
||||
|
|
|
@ -59,6 +59,7 @@ GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_b
|
|||
GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
|
||||
|
||||
GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
|
||||
GGML_API void ggml_tallocr_set_buffer(ggml_tallocr_t talloc, struct ggml_backend_buffer * buffer);
|
||||
|
||||
GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc);
|
||||
GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc);
|
||||
|
|
439
llama.cpp
439
llama.cpp
|
@ -1663,7 +1663,9 @@ struct llama_context {
|
|||
std::vector<uint8_t> buf_compute_meta;
|
||||
ggml_backend_sched_t sched = nullptr;
|
||||
// allocator for the input tensors
|
||||
ggml_tallocr * alloc = nullptr;
|
||||
ggml_tallocr * alloc_cpu = nullptr;
|
||||
|
||||
std::vector<ggml_backend_buffer_t> buf_cpu_ub;
|
||||
|
||||
// temporary buffer for copying data to/from the backend
|
||||
std::vector<no_init<uint8_t>> buf_copy;
|
||||
|
@ -3208,7 +3210,8 @@ static bool llm_load_tensors(
|
|||
const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
|
||||
|
||||
// there is very little benefit to offloading the input layer, so always keep it on the CPU
|
||||
model.buft_input = llama_default_buffer_type_cpu(true);
|
||||
//model.buft_input = llama_default_buffer_type_cpu(true);
|
||||
model.buft_input = llama_default_buffer_type_offload(main_gpu);
|
||||
|
||||
model.buft_layer.resize(n_layer);
|
||||
|
||||
|
@ -5955,7 +5958,7 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
const auto & model = lctx.model;
|
||||
|
||||
// check if we should build the worst-case graph (for memory measurement)
|
||||
const bool worst_case = ggml_tallocr_is_measure(lctx.alloc);
|
||||
const bool worst_case = ggml_tallocr_is_measure(lctx.alloc_cpu);
|
||||
|
||||
// keep track of the input that has already been allocated
|
||||
bool alloc_inp_tokens = false;
|
||||
|
@ -5978,9 +5981,9 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
//
|
||||
|
||||
if (!alloc_inp_tokens && strcmp(name, "inp_tokens") == 0) {
|
||||
ggml_tallocr_alloc(lctx.alloc, cur);
|
||||
ggml_tallocr_alloc(lctx.alloc_cpu, cur);
|
||||
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc) && batch.token) {
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc_cpu) && batch.token) {
|
||||
const int64_t n_tokens = cur->ne[0];
|
||||
|
||||
ggml_backend_tensor_set(cur, batch.token, 0, n_tokens*ggml_element_size(cur));
|
||||
|
@ -5990,9 +5993,9 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
}
|
||||
|
||||
if (!alloc_inp_embd && strcmp(name, "inp_embd") == 0 && batch.embd) {
|
||||
ggml_tallocr_alloc(lctx.alloc, cur);
|
||||
ggml_tallocr_alloc(lctx.alloc_cpu, cur);
|
||||
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc) && batch.embd) {
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc_cpu) && batch.embd) {
|
||||
const int64_t n_embd = cur->ne[0];
|
||||
const int64_t n_tokens = cur->ne[1];
|
||||
|
||||
|
@ -6003,9 +6006,9 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
}
|
||||
|
||||
if (!alloc_inp_pos && strcmp(name, "inp_pos") == 0) {
|
||||
ggml_tallocr_alloc(lctx.alloc, cur);
|
||||
ggml_tallocr_alloc(lctx.alloc_cpu, cur);
|
||||
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc) && batch.pos) {
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc_cpu) && batch.pos) {
|
||||
const int64_t n_tokens = cur->ne[0];
|
||||
|
||||
static_assert(std::is_same<llama_pos, int32_t>::value, "llama_pos must be int32_t");
|
||||
|
@ -6016,9 +6019,9 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
}
|
||||
|
||||
if (!alloc_inp_KQ_mask && strcmp(name, "KQ_mask") == 0) {
|
||||
ggml_tallocr_alloc(lctx.alloc, cur);
|
||||
ggml_tallocr_alloc(lctx.alloc_cpu, cur);
|
||||
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc)) {
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc_cpu)) {
|
||||
const int64_t n_kv = cur->ne[0];
|
||||
const int64_t n_tokens = cur->ne[1];
|
||||
|
||||
|
@ -6056,9 +6059,9 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
}
|
||||
|
||||
if (!alloc_inp_K_shift && strcmp(name, "K_shift") == 0) {
|
||||
ggml_tallocr_alloc(lctx.alloc, cur);
|
||||
ggml_tallocr_alloc(lctx.alloc_cpu, cur);
|
||||
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc)) {
|
||||
if (!ggml_tallocr_is_measure(lctx.alloc_cpu)) {
|
||||
const int64_t n_ctx = cur->ne[0];
|
||||
|
||||
int32_t * data;
|
||||
|
@ -6161,10 +6164,11 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
//
|
||||
static int llama_decode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch batch) {
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
llama_batch all_batch) {
|
||||
|
||||
if (n_tokens == 0) {
|
||||
const uint32_t n_tokens_all = all_batch.n_tokens;
|
||||
|
||||
if (n_tokens_all == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
@ -6173,12 +6177,11 @@ static int llama_decode_internal(
|
|||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
const auto n_batch = cparams.n_batch;
|
||||
//const auto n_batch = cparams.n_batch;
|
||||
|
||||
GGML_ASSERT(n_tokens <= n_batch);
|
||||
GGML_ASSERT((!all_batch.token && all_batch.embd) || (all_batch.token && !all_batch.embd)); // NOLINT
|
||||
|
||||
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_ctx);
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
|
@ -6188,205 +6191,255 @@ static int llama_decode_internal(
|
|||
//ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
|
||||
#endif
|
||||
|
||||
GGML_ASSERT(n_threads > 0);
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_vocab = hparams.n_vocab;
|
||||
|
||||
// helpers for smoother batch API transition
|
||||
// after deprecating the llama_eval calls, these will be removed
|
||||
std::vector<llama_pos> pos;
|
||||
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id_arr;
|
||||
std::vector<std::vector<llama_seq_id>> seq_id;
|
||||
auto & logits_out = lctx.logits;
|
||||
|
||||
if (batch.pos == nullptr) {
|
||||
pos.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
|
||||
}
|
||||
|
||||
batch.pos = pos.data();
|
||||
if (all_batch.logits) {
|
||||
logits_out.resize(n_vocab * n_tokens_all);
|
||||
} else if (lctx.logits_all) {
|
||||
logits_out.resize(n_vocab * n_tokens_all);
|
||||
} else {
|
||||
logits_out.resize(n_vocab);
|
||||
}
|
||||
|
||||
if (batch.seq_id == nullptr) {
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_id.resize(n_tokens);
|
||||
seq_id_arr.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_seq_id[i] = 1;
|
||||
seq_id[i].resize(1);
|
||||
seq_id[i][0] = batch.all_seq_id;
|
||||
seq_id_arr[i] = seq_id[i].data();
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
auto & logits_valid = lctx.logits_valid;
|
||||
logits_valid.clear();
|
||||
logits_valid.resize(n_tokens_all);
|
||||
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
batch.seq_id = seq_id_arr.data();
|
||||
}
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, batch);
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
|
||||
|
||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
||||
if (strcmp(embeddings->name, "result_norm") != 0) {
|
||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||
}
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||
// TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
|
||||
// we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
|
||||
// with the BLAS calls. need a better solution
|
||||
if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
|
||||
n_threads = std::min(4, n_threads);
|
||||
}
|
||||
|
||||
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
|
||||
if (ggml_cpu_has_cublas() && fully_offloaded) {
|
||||
n_threads = 1;
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
|
||||
logits_out.clear();
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
if (ggml_backend_is_metal(lctx.backend_metal)) {
|
||||
ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
|
||||
}
|
||||
#endif
|
||||
const uint32_t n_microbatch = 256;
|
||||
|
||||
if (lctx.backend_cpu != nullptr) {
|
||||
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
|
||||
}
|
||||
ggml_backend_sched_graph_compute(lctx.sched, gf);
|
||||
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_microbatch) {
|
||||
const uint32_t n_tokens = std::min(n_microbatch, n_tokens_all - cur_token);
|
||||
|
||||
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
|
||||
llama_batch batch = {
|
||||
/* .n_tokens = */ (int32_t) n_tokens,
|
||||
/* .token = */ all_batch.token ? all_batch.token + cur_token : nullptr,
|
||||
/* .embd = */ all_batch.embd ? all_batch.embd + cur_token*n_embd : nullptr,
|
||||
/* .pos = */ all_batch.pos ? all_batch.pos + cur_token : nullptr,
|
||||
/* .n_seq_id = */ all_batch.n_seq_id ? all_batch.n_seq_id + cur_token : nullptr,
|
||||
/* .seq_id = */ all_batch.seq_id ? all_batch.seq_id + cur_token : nullptr,
|
||||
/* .logits = */ all_batch.logits ? all_batch.logits + cur_token : nullptr,
|
||||
/* .all_pos_0 = */ all_batch.all_pos_0 + (llama_pos) cur_token*all_batch.all_pos_1,
|
||||
/* .all_pos_1 = */ all_batch.all_pos_1,
|
||||
/* .all_seq_id = */ all_batch.all_seq_id,
|
||||
};
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
GGML_ASSERT(n_threads > 0);
|
||||
|
||||
// update the kv ring buffer
|
||||
{
|
||||
if (kv_self.has_shift) {
|
||||
kv_self.has_shift = false;
|
||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||
kv_self.cells[i].delta = 0;
|
||||
// helpers for smoother batch API transition
|
||||
// after deprecating the llama_eval calls, these will be removed
|
||||
std::vector<llama_pos> pos;
|
||||
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id_arr;
|
||||
std::vector<std::vector<llama_seq_id>> seq_id;
|
||||
|
||||
if (batch.pos == nullptr) {
|
||||
pos.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
|
||||
}
|
||||
|
||||
batch.pos = pos.data();
|
||||
}
|
||||
|
||||
kv_self.head += n_tokens;
|
||||
if (batch.seq_id == nullptr) {
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_id.resize(n_tokens);
|
||||
seq_id_arr.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_seq_id[i] = 1;
|
||||
seq_id[i].resize(1);
|
||||
seq_id[i][0] = batch.all_seq_id;
|
||||
seq_id_arr[i] = seq_id[i].data();
|
||||
}
|
||||
|
||||
// Ensure kv cache head points to a valid index.
|
||||
if (kv_self.head >= kv_self.size) {
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
batch.seq_id = seq_id_arr.data();
|
||||
}
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// print timing information per ggml operation (for debugging purposes)
|
||||
// requires GGML_PERF to be defined
|
||||
ggml_graph_print(gf);
|
||||
#endif
|
||||
|
||||
// plot the computation graph in dot format (for debugging purposes)
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
// extract logits
|
||||
// TODO: do not compute and extract logits if only embeddings are needed
|
||||
// need to update the graphs to skip "result_output"
|
||||
{
|
||||
auto & logits_out = lctx.logits;
|
||||
|
||||
#ifndef NDEBUG
|
||||
auto & logits_valid = lctx.logits_valid;
|
||||
logits_valid.clear();
|
||||
logits_valid.resize(n_tokens);
|
||||
|
||||
logits_out.clear();
|
||||
#endif
|
||||
|
||||
ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res);
|
||||
GGML_ASSERT(res_backend != nullptr);
|
||||
if (batch.logits) {
|
||||
logits_out.resize(n_vocab * n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
|
||||
#ifndef NDEBUG
|
||||
logits_valid[i] = true;
|
||||
#endif
|
||||
}
|
||||
} else if (lctx.logits_all) {
|
||||
logits_out.resize(n_vocab * n_tokens);
|
||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
|
||||
#ifndef NDEBUG
|
||||
std::fill(logits_valid.begin(), logits_valid.end(), true);
|
||||
#endif
|
||||
} else {
|
||||
logits_out.resize(n_vocab);
|
||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
|
||||
#ifndef NDEBUG
|
||||
logits_valid[0] = true;
|
||||
#endif
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find a slot in the cache", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
int i_ub = cur_token / n_microbatch;
|
||||
size_t n_buf = lctx.buf_cpu_ub.size();
|
||||
if (i_ub != 0 && i_ub % n_buf == 0) {
|
||||
// sync all backends
|
||||
printf("not enough buffers, syncing now\n");
|
||||
// TODO: ggml_backend_sched_synchronize()
|
||||
for (auto * backend : lctx.backends) {
|
||||
ggml_backend_synchronize(backend);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tallocr_set_buffer(lctx.alloc_cpu, lctx.buf_cpu_ub[i_ub % n_buf]);
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, batch);
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
|
||||
|
||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
||||
if (strcmp(embeddings->name, "result_norm") != 0) {
|
||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||
}
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||
// TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
|
||||
// we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
|
||||
// with the BLAS calls. need a better solution
|
||||
if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
|
||||
n_threads = std::min(4, n_threads);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
if (lctx.backend_metal != nullptr) {
|
||||
ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (lctx.backend_cpu != nullptr) {
|
||||
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
|
||||
}
|
||||
|
||||
ggml_backend_sched_graph_compute(lctx.sched, gf);
|
||||
|
||||
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
|
||||
// update the kv ring buffer
|
||||
{
|
||||
if (kv_self.has_shift) {
|
||||
kv_self.has_shift = false;
|
||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||
kv_self.cells[i].delta = 0;
|
||||
}
|
||||
}
|
||||
|
||||
kv_self.head += n_tokens;
|
||||
|
||||
// Ensure kv cache head points to a valid index.
|
||||
if (kv_self.head >= kv_self.size) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// print timing information per ggml operation (for debugging purposes)
|
||||
// requires GGML_PERF to be defined
|
||||
ggml_graph_print(gf);
|
||||
#endif
|
||||
|
||||
// plot the computation graph in dot format (for debugging purposes)
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
// extract logits
|
||||
// TODO: do not compute and extract logits if only embeddings are needed
|
||||
// need to update the graphs to skip "result_output"
|
||||
{
|
||||
ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res);
|
||||
GGML_ASSERT(res_backend != nullptr);
|
||||
if (batch.logits) {
|
||||
//logits_out.resize(n_vocab * n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + n_vocab*(cur_token + i), n_vocab*i*sizeof(float), n_vocab*sizeof(float));
|
||||
#ifndef NDEBUG
|
||||
logits_valid[i] = true;
|
||||
#endif
|
||||
}
|
||||
} else if (lctx.logits_all) {
|
||||
//logits_out.resize(n_vocab * n_tokens);
|
||||
//ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
|
||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + cur_token*n_vocab, 0, n_vocab*n_tokens*sizeof(float));
|
||||
#ifndef NDEBUG
|
||||
std::fill(logits_valid.begin(), logits_valid.end(), true);
|
||||
#endif
|
||||
} else {
|
||||
if (cur_token + n_tokens >= n_tokens_all) {
|
||||
//logits_out.resize(n_vocab);
|
||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
|
||||
}
|
||||
//ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
|
||||
#ifndef NDEBUG
|
||||
logits_valid[0] = true;
|
||||
#endif
|
||||
}
|
||||
//ggml_backend_synchronize(res_backend);
|
||||
}
|
||||
|
||||
// FIXME
|
||||
// extract embeddings
|
||||
if (!lctx.embedding.empty()) {
|
||||
GGML_ASSERT(!"not implemented");
|
||||
auto & embedding_out = lctx.embedding;
|
||||
|
||||
embedding_out.resize(n_embd);
|
||||
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
|
||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
|
||||
//ggml_backend_synchronize(embeddings_backend);
|
||||
}
|
||||
ggml_backend_synchronize(res_backend);
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (!lctx.embedding.empty()) {
|
||||
auto & embedding_out = lctx.embedding;
|
||||
|
||||
embedding_out.resize(n_embd);
|
||||
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
|
||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
|
||||
ggml_backend_synchronize(embeddings_backend);
|
||||
// TODO: ggml_backend_sched_synchronize()
|
||||
for (auto * backend : lctx.backends) {
|
||||
ggml_backend_synchronize(backend);
|
||||
}
|
||||
|
||||
// measure the performance only for the single-token evals
|
||||
if (n_tokens == 1) {
|
||||
if (n_tokens_all == 1) {
|
||||
lctx.t_eval_us += ggml_time_us() - t_start_us;
|
||||
lctx.n_eval++;
|
||||
}
|
||||
else if (n_tokens > 1) {
|
||||
else if (n_tokens_all > 1) {
|
||||
lctx.t_p_eval_us += ggml_time_us() - t_start_us;
|
||||
lctx.n_p_eval += n_tokens;
|
||||
lctx.n_p_eval += n_tokens_all;
|
||||
}
|
||||
|
||||
// get a more accurate load time, upon first eval
|
||||
|
@ -9402,7 +9455,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
|
||||
ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
|
||||
ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
|
||||
|
||||
// build worst-case graph
|
||||
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
|
||||
|
@ -9415,7 +9468,17 @@ struct llama_context * llama_new_context_with_model(
|
|||
// note: the number of splits during measure is higher than during inference due to the kv shift
|
||||
int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
|
||||
LLAMA_LOG_INFO("%s: graph splits (measure): %d\n", __func__, n_splits);
|
||||
ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
|
||||
ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
|
||||
|
||||
// duplicate cpu buffers for microbatching
|
||||
ggml_backend_buffer_t buf_cpu = ggml_tallocr_get_buffer(ctx->alloc_cpu);
|
||||
size_t buf_size = ggml_backend_buffer_get_size(buf_cpu);
|
||||
ctx->buf_cpu_ub.push_back(buf_cpu);
|
||||
int n_ub = 64;
|
||||
for (int i = 1; i < n_ub; ++i) {
|
||||
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_size);
|
||||
ctx->buf_cpu_ub.push_back(buf);
|
||||
}
|
||||
|
||||
for (ggml_backend_t backend : ctx->backends) {
|
||||
ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(ctx->sched, backend);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue