fix CUDA split buffers
This commit is contained in:
parent
3a774427ae
commit
c0fe6298ae
3 changed files with 21 additions and 11 deletions
|
@ -1051,8 +1051,9 @@ struct ggml_backend_sched {
|
||||||
struct ggml_cgraph * graph;
|
struct ggml_cgraph * graph;
|
||||||
|
|
||||||
// graph splits
|
// graph splits
|
||||||
struct ggml_backend_sched_split splits[GGML_SCHED_MAX_SPLITS];
|
struct ggml_backend_sched_split * splits;
|
||||||
int n_splits;
|
int n_splits;
|
||||||
|
int splits_capacity;
|
||||||
|
|
||||||
// pipeline parallelism support
|
// pipeline parallelism support
|
||||||
int n_copies;
|
int n_copies;
|
||||||
|
@ -1443,6 +1444,10 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||||
if (node_backend_id != cur_backend_id || offload) {
|
if (node_backend_id != cur_backend_id || offload) {
|
||||||
split->i_end = i;
|
split->i_end = i;
|
||||||
i_split++;
|
i_split++;
|
||||||
|
if (i_split >= sched->splits_capacity) {
|
||||||
|
sched->splits_capacity *= 2;
|
||||||
|
sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
|
||||||
|
}
|
||||||
GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
|
GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
|
||||||
split = &sched->splits[i_split];
|
split = &sched->splits[i_split];
|
||||||
split->backend_id = node_backend_id;
|
split->backend_id = node_backend_id;
|
||||||
|
@ -1711,7 +1716,9 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
||||||
|
|
||||||
sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
|
sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
|
||||||
|
|
||||||
GGML_ASSERT(sched->n_copies <= GGML_SCHED_MAX_COPIES);
|
const int initial_splits_capacity = 16;
|
||||||
|
sched->splits = calloc(sizeof(sched->splits[0]), initial_splits_capacity);
|
||||||
|
sched->splits_capacity = initial_splits_capacity;
|
||||||
|
|
||||||
for (int b = 0; b < n_backends; b++) {
|
for (int b = 0; b < n_backends; b++) {
|
||||||
sched->backends[b] = backends[b];
|
sched->backends[b] = backends[b];
|
||||||
|
@ -1742,6 +1749,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
|
||||||
}
|
}
|
||||||
ggml_gallocr_free(sched->galloc);
|
ggml_gallocr_free(sched->galloc);
|
||||||
ggml_free(sched->ctx);
|
ggml_free(sched->ctx);
|
||||||
|
free(sched->splits);
|
||||||
free(sched->hash_set.keys);
|
free(sched->hash_set.keys);
|
||||||
free(sched->tensor_backend_id);
|
free(sched->tensor_backend_id);
|
||||||
free(sched->tensor_copies);
|
free(sched->tensor_copies);
|
||||||
|
|
14
ggml-cuda.cu
14
ggml-cuda.cu
|
@ -10755,6 +10755,8 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
|
||||||
};
|
};
|
||||||
|
|
||||||
GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
|
GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
|
||||||
|
ggml_init_cublas();
|
||||||
|
|
||||||
// FIXME: this is not thread safe
|
// FIXME: this is not thread safe
|
||||||
if (device >= ggml_backend_cuda_get_device_count()) {
|
if (device >= ggml_backend_cuda_get_device_count()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -11039,6 +11041,8 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface
|
||||||
};
|
};
|
||||||
|
|
||||||
GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
|
GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
|
||||||
|
ggml_init_cublas();
|
||||||
|
|
||||||
// FIXME: this is not thread safe
|
// FIXME: this is not thread safe
|
||||||
static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
|
static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
|
||||||
|
|
||||||
|
@ -11389,15 +11393,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
|
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
|
||||||
const ggml_tensor * dst = op;
|
|
||||||
|
|
||||||
const int min_batch_size = 32;
|
const int min_batch_size = 32;
|
||||||
|
|
||||||
if (dst->ne[1] > min_batch_size && dst->op != GGML_OP_GET_ROWS) {
|
return op->ne[1] > min_batch_size && op->op != GGML_OP_GET_ROWS;
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
|
static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
|
||||||
|
@ -11476,7 +11474,7 @@ static ggml_guid_t ggml_backend_cuda_guid() {
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
|
GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
|
||||||
ggml_init_cublas(); // TODO: remove from ggml.c
|
ggml_init_cublas();
|
||||||
|
|
||||||
if (device < 0 || device >= ggml_cuda_get_device_count()) {
|
if (device < 0 || device >= ggml_cuda_get_device_count()) {
|
||||||
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
|
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
|
||||||
|
|
|
@ -5039,7 +5039,11 @@ static bool llm_load_tensors(
|
||||||
ml.get_mapping_range(&first, &last, ctx);
|
ml.get_mapping_range(&first, &last, ctx);
|
||||||
buf = ggml_backend_cpu_buffer_from_ptr((char *) ml.mapping->addr + first, last - first);
|
buf = ggml_backend_cpu_buffer_from_ptr((char *) ml.mapping->addr + first, last - first);
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
ggml_backend_cuda_register_host_buffer((char *) ml.mapping->addr + first, last - first);
|
if (n_layer >= n_gpu_layers) {
|
||||||
|
ggml_backend_cuda_register_host_buffer(
|
||||||
|
ggml_backend_buffer_get_base(buf),
|
||||||
|
ggml_backend_buffer_get_size(buf));
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue