sched : support async weight copy

This commit is contained in:
slaren 2024-05-14 20:41:33 +02:00
parent 8f7080bf48
commit 5de9b743f8
6 changed files with 181 additions and 18 deletions

View file

@ -1358,6 +1358,7 @@ int main(int argc, char ** argv) {
}
p->print_test(t);
fflush(p->fout);
llama_print_timings(ctx);

View file

@ -114,6 +114,8 @@ extern "C" {
void (*GGML_CALL event_record) (ggml_backend_event_t event);
void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
void (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
ggml_backend_t (*GGML_CALL backend_dup)(ggml_backend_t backend);
};
struct ggml_backend {

View file

@ -180,6 +180,13 @@ void ggml_backend_free(ggml_backend_t backend) {
backend->iface.free(backend);
}
ggml_backend_t ggml_backend_dup(ggml_backend_t backend) {
if (backend->iface.backend_dup) {
return backend->iface.backend_dup(backend);
}
return backend;
}
ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
return backend->iface.get_default_buffer_type(backend);
}
@ -855,6 +862,7 @@ static struct ggml_backend_i cpu_backend_i = {
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .event_synchronize = */ NULL,
/* .backend_dup = */ NULL,
};
static ggml_guid_t ggml_backend_cpu_guid(void) {
@ -1026,16 +1034,34 @@ static bool ggml_is_view_op(enum ggml_op op) {
#define GGML_SCHED_MAX_COPIES 4
#endif
#ifndef GGML_SCHED_MAX_COPY_STREAMS
#define GGML_SCHED_MAX_COPY_STREAMS 8
#endif
struct ggml_backend_sched_split {
int backend_id;
int i_start;
int i_end;
// input tensors from other backends
struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
int n_inputs;
// copy stream to use to copy the inputs that are weights (-1 = no copy stream)
int w_copy_stream_id;
// graph view of this split
struct ggml_cgraph graph;
};
struct ggml_backend_sched_copy_stream {
ggml_backend_t stream;
ggml_backend_buffer_t buffer;
ggml_backend_event_t event_copy;
ggml_backend_event_t event_use;
size_t max_size;
};
struct ggml_backend_sched {
bool is_reset; // true if the scheduler has been reset since the last graph split
bool is_alloc;
@ -1046,6 +1072,9 @@ struct ggml_backend_sched {
ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
ggml_gallocr_t galloc;
struct ggml_backend_sched_copy_stream copy_streams[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPY_STREAMS];
int cur_copy_stream[GGML_SCHED_MAX_BACKENDS];
// hash keys of the nodes in the graph
struct ggml_hash_set hash_set;
// hash values
@ -1228,6 +1257,14 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
//#define DEBUG_PASS3
//#define DEBUG_PASS4
static void init_split(ggml_backend_sched_t sched, struct ggml_backend_sched_split * split, int backend_id, int i_start) {
split->backend_id = backend_id;
split->i_start = i_start;
split->i_end = -1;
split->n_inputs = 0;
split->w_copy_stream_id = -1;
}
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits
@ -1406,19 +1443,17 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// pass 4: split graph, find tensors that need to be copied
{
int i_split = 0;
int cur_backend_id = 0;
struct ggml_backend_sched_split * split = &sched->splits[0];
// find the backend of the first split, skipping view ops
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (!ggml_is_view_op(node->op)) {
split->backend_id = tensor_backend_id(node);
cur_backend_id = tensor_backend_id(node);
break;
}
}
split->i_start = 0;
split->n_inputs = 0;
memset(split->inputs, 0, sizeof(split->inputs)); //HACK
int cur_backend_id = split->backend_id;
init_split(sched, split, cur_backend_id, 0);
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
@ -1433,6 +1468,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// check if we should start a new split based on the sources of the current node
bool need_new_split = false;
if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
if (split->w_copy_stream_id != -1) {
// the previous op used a weight copy stream, start a new split to allow the next copy to start immediately after the op
need_new_split = true;
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
@ -1452,7 +1492,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
const size_t id = hash_id(src);
int src_backend_id = sched->tensor_backend_id[id];
if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) {
//printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
need_new_split = true;
break;
}
@ -1470,10 +1509,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
}
GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
split = &sched->splits[i_split];
split->backend_id = node_backend_id;
split->i_start = i;
split->n_inputs = 0;
cur_backend_id = node_backend_id;
init_split(sched, split, cur_backend_id, i);
}
// find inputs that are not on the same backend
@ -1529,6 +1566,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
int n_inputs = split->n_inputs++;
GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
split->inputs[n_inputs] = src;
if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS && split->w_copy_stream_id == -1 && GGML_SCHED_MAX_COPY_STREAMS > 0) {
split->w_copy_stream_id = sched->cur_copy_stream[cur_backend_id];
sched->copy_streams[cur_backend_id][split->w_copy_stream_id].max_size = MAX(
sched->copy_streams[cur_backend_id][split->w_copy_stream_id].max_size,
ggml_backend_buft_get_alloc_size(sched->bufts[cur_backend_id], src));
sched->cur_copy_stream[cur_backend_id] = (sched->cur_copy_stream[cur_backend_id] + 1) % GGML_SCHED_MAX_COPY_STREAMS;
}
}
node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy];
}
@ -1540,6 +1584,10 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
#ifdef DEBUG_PASS4
fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
#endif
if (getenv("GGML_DEBUG_SCHED")) {
fprintf(stderr, "SPLIT GRAPH\n");
ggml_backend_sched_print_assignments(sched, graph);
}
// create copies of the graph for each split
// TODO: avoid this copy
@ -1613,6 +1661,25 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
}
static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
// allocate weights in the copy buffers
for (int s = 0; s < sched->n_splits; s++) {
struct ggml_backend_sched_split * split = &sched->splits[s];
if (split->w_copy_stream_id != -1) {
struct ggml_backend_sched_copy_stream * stream = &sched->copy_streams[split->backend_id][split->w_copy_stream_id];
ggml_backend_buffer_t buffer = stream->buffer;
if (buffer == NULL) {
continue;
}
for (int j = 0; j < split->n_inputs; j++) {
struct ggml_tensor * input = split->inputs[j];
if (input->buffer != NULL && input->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split->backend_id][sched->cur_copy];
ggml_backend_tensor_alloc(buffer, input_cpy, ggml_backend_buffer_get_base(buffer));
}
}
}
}
// allocate graph
if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
// the re-allocation may cause the split inputs to be moved to a different address
@ -1637,6 +1704,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
struct ggml_backend_sched_split * split = &splits[i];
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
struct ggml_backend_sched_copy_stream * stream = NULL;
if (split->w_copy_stream_id != -1) {
stream = &sched->copy_streams[split_backend_id][split->w_copy_stream_id];
}
// copy the input tensors to the split backend
for (int j = 0; j < split->n_inputs; j++) {
@ -1644,7 +1716,9 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
struct ggml_tensor * input = split->inputs[j];
struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy];
if (input->flags & GGML_TENSOR_FLAG_INPUT) {
if (input->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS && stream && stream->stream) {
ggml_backend_tensor_copy_async(input_backend, stream->stream, input, input_cpy);
} else if (input->flags & GGML_TENSOR_FLAG_INPUT) {
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
@ -1663,6 +1737,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
}
if (stream && stream->stream) {
ggml_backend_event_record(stream->event_copy);
ggml_backend_event_wait(split_backend, stream->event_copy);
}
if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
if (ec != GGML_STATUS_SUCCESS) {
@ -1702,6 +1781,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
}
// record event of this copy stream
if (stream && stream->stream) {
ggml_backend_event_record(stream->event_use);
ggml_backend_event_wait(stream->stream, stream->event_use);
}
// record the event of this copy
if (split->n_inputs > 0) {
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
@ -1766,11 +1851,19 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
if (sched == NULL) {
return;
}
for (int b = 0; b < sched->n_backends; b++) {
for (int c = 0; c < sched->n_copies; c++) {
ggml_backend_event_free(sched->events[b][c]);
}
for (int s = 0; s < GGML_SCHED_MAX_COPY_STREAMS; s++) {
ggml_backend_buffer_free(sched->copy_streams[b][s].buffer);
ggml_backend_event_free(sched->copy_streams[b][s].event_copy);
ggml_backend_event_free(sched->copy_streams[b][s].event_use);
ggml_backend_free(sched->copy_streams[b][s].stream);
}
}
ggml_gallocr_free(sched->galloc);
ggml_free(sched->ctx);
free(sched->splits);
@ -1789,6 +1882,7 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT
memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size);
memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size);
memset(sched->cur_copy_stream, 0, sizeof(sched->cur_copy_stream[0]) * sched->n_backends);
sched->is_reset = true;
}
@ -1800,7 +1894,46 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
ggml_backend_sched_split_graph(sched, measure_graph);
// TODO: extract this to a separate function
// allocate tensor copy streams
for (int b = 0; b < sched->n_backends; b++) {
for (int j = 0; j < GGML_SCHED_MAX_COPY_STREAMS; j++) {
struct ggml_backend_sched_copy_stream * stream = &sched->copy_streams[b][j];
if (stream->max_size > 0) {
// backend
if (!stream->stream) {
stream->stream = ggml_backend_dup(sched->backends[b]);
}
if (!stream->stream) {
continue;
}
// events
if (!stream->event_copy) {
stream->event_copy = ggml_backend_event_new(stream->stream);
}
if (!stream->event_use) {
stream->event_use = ggml_backend_event_new(sched->backends[b]);
}
if (!stream->event_copy || !stream->event_use) {
continue;
}
// buffer
if (!stream->buffer || ggml_backend_buffer_get_size(stream->buffer) < stream->max_size) {
ggml_backend_buffer_free(stream->buffer);
stream->buffer = ggml_backend_buft_alloc_buffer(sched->bufts[b], stream->max_size);
if (stream->buffer == NULL) {
fprintf(stderr, "%s: failed to allocate buffer for copy stream\n", __func__);
return false;
}
}
}
}
}
if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
return false;
}
@ -1868,7 +2001,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
size_t size = ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
for (int i = 0; i < GGML_SCHED_MAX_COPY_STREAMS; i++) {
if (sched->copy_streams[backend_index][i].buffer == NULL) {
continue;
}
size += ggml_backend_buffer_get_size(sched->copy_streams[backend_index][i].buffer);
}
return size;
}
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {

View file

@ -50,9 +50,10 @@ extern "C" {
// Backend
//
GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend);
GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_dup(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);

View file

@ -2920,6 +2920,12 @@ static void ggml_backend_cuda_event_synchronize(ggml_backend_event_t event) {
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
static ggml_backend_t ggml_backend_cuda_dup(ggml_backend_t backend) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
return ggml_backend_cuda_init(cuda_ctx->device);
}
static ggml_backend_i ggml_backend_cuda_interface = {
/* .get_name = */ ggml_backend_cuda_name,
/* .free = */ ggml_backend_cuda_free,
@ -2939,6 +2945,7 @@ static ggml_backend_i ggml_backend_cuda_interface = {
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
/* .event_synchronize = */ ggml_backend_cuda_event_synchronize,
/* .backend_dup = */ ggml_backend_cuda_dup,
};
static ggml_guid_t ggml_backend_cuda_guid() {

View file

@ -6360,8 +6360,6 @@ static void llm_build_kv_store(
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
@ -6380,7 +6378,19 @@ static void llm_build_kv_store(
}
cb(v_cache_view, "v_cache_view", il);
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
struct ggml_tensor * k_cur_cast = ggml_cast(ctx, k_cur, k_cache_view->type);
struct ggml_tensor * v_cur_cast = ggml_cast(ctx, v_cur, v_cache_view->type);
ggml_build_forward_expand(graph, k_cur_cast);
ggml_build_forward_expand(graph, v_cur_cast);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur_cast, k_cache_view));
//ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_cast, v_cache_view));
//ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
}
static struct ggml_tensor * llm_build_norm(