cuda : add ggml-backend split buffer support

This commit is contained in:
slaren 2024-01-06 23:07:43 +01:00
parent ece0b0d855
commit 2f2c36799d
4 changed files with 318 additions and 74 deletions

View file

@ -1015,9 +1015,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr != NULL) {
if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
// skip cpu
cur_allocr = NULL;
}
else {
} else {
cur_allocr = node_allocr;
}
} else {
@ -1038,9 +1038,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr != NULL) {
if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
// skip cpu
cur_allocr = NULL;
}
else {
} else {
cur_allocr = node_allocr;
}
} else {
@ -1274,7 +1274,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
GGML_ASSERT(false);
}
// TODO: avoid this copy if it was already copied in a previous split, and the input didn't change
// this is important to avoid copying constants such as KQ_mask and inp_pos multiple time
// this is important to avoid copying constants such as KQ_mask and inp_pos multiple times
ggml_backend_tensor_copy(input, input_cpy);
}
// ggml_backend_synchronize(split_backend);

View file

@ -10,7 +10,11 @@
#include <stdio.h>
#include <string>
#include <vector>
#include <map>
#include <array>
#include "ggml-cuda.h"
#include "ggml.h"
#include "ggml-backend-impl.h"
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@ -114,10 +118,6 @@
#endif // defined(GGML_USE_HIPBLAS)
#include "ggml-cuda.h"
#include "ggml.h"
#include "ggml-backend-impl.h"
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_OFFSET_AMD 1000000
@ -546,7 +546,7 @@ static void ggml_cuda_set_device(const int device) {
static int g_device_count = -1;
static int g_main_device = 0;
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static std::array<float, GGML_CUDA_MAX_DEVICES> g_default_tensor_split = {};
struct cuda_device_capabilities {
int cc; // compute capability
@ -6854,8 +6854,9 @@ void ggml_init_cublas() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
g_tensor_split[id] = total_vram;
g_default_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
g_device_caps[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
#else
@ -6863,7 +6864,7 @@ void ggml_init_cublas() {
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
for (int id = 0; id < g_device_count; ++id) {
g_tensor_split[id] /= total_vram;
g_default_tensor_split[id] /= total_vram;
}
for (int id = 0; id < g_device_count; ++id) {
@ -6887,33 +6888,6 @@ void ggml_init_cublas() {
}
}
// TODO: cleanup this after the split buffer type is implemented
#if 0
void ggml_cuda_set_tensor_split(const float * tensor_split) {
if (tensor_split == nullptr) {
return;
}
bool all_zero = true;
for (int i = 0; i < g_device_count; ++i) {
if (tensor_split[i] != 0.0f) {
all_zero = false;
break;
}
}
if (all_zero) {
return;
}
float split_sum = 0.0f;
for (int i = 0; i < g_device_count; ++i) {
g_tensor_split[i] = split_sum;
split_sum += tensor_split[i];
}
for (int i = 0; i < g_device_count; ++i) {
g_tensor_split[i] /= split_sum;
}
}
#endif
void * ggml_cuda_host_malloc(size_t size) {
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
return nullptr;
@ -7365,11 +7339,11 @@ static void ggml_cuda_op_mul_mat_q(
(void) src1_ddf_i;
}
static int64_t get_row_rounding(ggml_type type) {
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
int64_t min_compute_capability = INT_MAX;
int64_t max_compute_capability = INT_MIN;
for (int id = 0; id < g_device_count; ++id) {
if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
if (tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
if (min_compute_capability > g_device_caps[id].cc) {
min_compute_capability = g_device_caps[id].cc;
}
@ -7426,6 +7400,21 @@ static int64_t get_row_rounding(ggml_type type) {
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
const int64_t nrows = ggml_nrows(tensor);
const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
*row_low = id == 0 ? 0 : nrows*tensor_split[id];
*row_low -= *row_low % rounding;
if (id == g_device_count - 1) {
*row_high = nrows;
} else {
*row_high = nrows*tensor_split[id + 1];
*row_high -= *row_high % rounding;
}
}
static void ggml_cuda_op_mul_mat_vec_q(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
@ -8023,6 +8012,11 @@ static void ggml_cuda_set_peer_access(const int n_tokens) {
peer_access_enabled = enable_peer_access;
}
// FIXME: move this somewhere else
struct ggml_backend_cuda_split_buffer_type_context {
std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
};
static void ggml_cuda_op_mul_mat(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
const bool convert_src1_to_q8_1) {
@ -8074,6 +8068,14 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12));
std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
if (split) {
// TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_GPU_SPLIT check
// GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
tensor_split = buft_ctx->tensor_split;
}
struct dev_data {
cuda_pool_alloc<char> src0_dd_alloc;
cuda_pool_alloc<float> src1_ddf_alloc;
@ -8101,17 +8103,17 @@ static void ggml_cuda_op_mul_mat(
// for multi GPU, get the row boundaries from tensor split
// and round to mul_mat_q tile sizes
if (split) {
const int64_t rounding = get_row_rounding(src0->type);
const int64_t rounding = get_row_rounding(src0->type, tensor_split);
if (id != 0) {
dev[id].row_low = ne01*g_tensor_split[id];
dev[id].row_low = ne01*tensor_split[id];
if (dev[id].row_low < ne01) {
dev[id].row_low -= dev[id].row_low % rounding;
}
}
if (id != g_device_count - 1) {
dev[id].row_high = ne01*g_tensor_split[id + 1];
dev[id].row_high = ne01*tensor_split[id + 1];
if (dev[id].row_high < ne01) {
dev[id].row_high -= dev[id].row_high % rounding;
}
@ -8657,10 +8659,17 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
int64_t min_compute_capability = INT_MAX;
for (int id = 0; id < g_device_count; ++id) {
if (min_compute_capability > g_device_caps[id].cc && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
min_compute_capability = g_device_caps[id].cc;
if (split) {
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
auto & tensor_split = buft_ctx->tensor_split;
for (int id = 0; id < g_device_count; ++id) {
if (min_compute_capability > g_device_caps[id].cc && tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
min_compute_capability = g_device_caps[id].cc;
}
}
} else {
min_compute_capability = g_device_caps[g_main_device].cc;
}
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
@ -9435,8 +9444,6 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[ctx->device][0]));
}
}
UNUSED(buffer);
}
static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@ -9469,7 +9476,7 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
CUDA_CHECK(cudaDeviceSynchronize());
}
static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
static struct ggml_backend_buffer_i ggml_cuda_backend_buffer_interface = {
/* .get_name = */ ggml_backend_cuda_buffer_get_name,
/* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
/* .get_base = */ ggml_backend_cuda_buffer_get_base,
@ -9510,7 +9517,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda(buft_ctx->device, dev_ptr);
return ggml_backend_buffer_init(buft, cuda_backend_buffer_interface, ctx, size);
return ggml_backend_buffer_init(buft, ggml_cuda_backend_buffer_interface, ctx, size);
}
static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@ -9560,6 +9567,7 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
};
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
// FIXME: this is not thread safe
static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
static bool ggml_backend_cuda_buffer_type_initialized = false;
@ -9577,16 +9585,255 @@ ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
return &ggml_backend_cuda_buffer_types[device];
}
// cuda split buffer
struct ggml_backend_cuda_split_buffer_context {
~ggml_backend_cuda_split_buffer_context() {
for (ggml_tensor_extra_gpu * extra : tensor_extras) {
for (int id = 0; id < g_device_count; ++id) {
for (int64_t is = 0; is < MAX_STREAMS; ++is) {
CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
}
CUDA_CHECK(cudaFree(extra->data_device[id]));
}
delete extra;
}
}
std::vector<ggml_tensor_extra_gpu *> tensor_extras;
};
static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backend_buffer_t buffer) {
return GGML_CUDA_NAME "_Split";
UNUSED(buffer);
}
static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
delete ctx;
}
static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
// the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
return (void *)0x1000;
UNUSED(buffer);
}
static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
const int64_t ne0 = tensor->ne[0];
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
ctx->tensor_extras.push_back(extra);
for (int id = 0; id < g_device_count; ++id) {
int64_t row_low, row_high;
get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
int64_t nrows_split = row_high - row_low;
if (nrows_split == 0) {
continue;
}
size_t size = ggml_nbytes_split(tensor, nrows_split);
const size_t original_size = size;
// pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
// FIXME: do not crash if cudaMalloc fails
// currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
ggml_cuda_set_device(id);
char * buf;
CUDA_CHECK(cudaMalloc(&buf, size));
// set padding to 0 to avoid possible NaN values
if (size > original_size) {
CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
}
extra->data_device[id] = buf;
for (int64_t is = 0; is < MAX_STREAMS; ++is) {
CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
}
}
tensor->backend = GGML_BACKEND_GPU_SPLIT;
tensor->extra = extra;
}
static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
// split tensors must always be set in their entirety at once
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
const int64_t ne0 = tensor->ne[0];
const size_t nb1 = tensor->nb[1];
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
for (int id = 0; id < g_device_count; ++id) {
int64_t row_low, row_high;
get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
int64_t nrows_split = row_high - row_low;
if (nrows_split == 0) {
continue;
}
const size_t offset_split = row_low*nb1;
size_t size = ggml_nbytes_split(tensor, nrows_split);
const size_t original_size = size;
// pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
const char * buf_host = (const char *)data + offset_split;
CUDA_CHECK(cudaMemcpy(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice));
}
}
static struct ggml_backend_buffer_i ggml_cuda_backend_split_buffer_interface = {
/* .get_name = */ ggml_backend_cuda_split_buffer_get_name,
/* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer,
/* .get_base = */ ggml_backend_cuda_split_buffer_get_base,
/* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor,
/* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
/* .get_tensor = */ NULL,
/* .cpy_tensor_from = */ NULL,
/* .cpy_tensor_to = */ NULL,
/* .clear = */ NULL,
};
// cuda split buffer type
static const char * ggml_backend_cuda_split_buffer_type_name(ggml_backend_buffer_type_t buft) {
return GGML_CUDA_NAME "_Split";
UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
// since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
// instead, we allocate them for each tensor separately in init_tensor
// however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
// as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();
return ggml_backend_buffer_init(buft, ggml_cuda_backend_split_buffer_interface, ctx, size);
}
static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 128;
UNUSED(buft);
}
static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, ggml_tensor * tensor) {
ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
size_t total_size = 0;
const int64_t ne0 = tensor->ne[0];
for (int id = 0; id < g_device_count; ++id) {
int64_t row_low, row_high;
get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
int64_t nrows_split = row_high - row_low;
if (nrows_split == 0) {
continue;
}
total_size += ggml_nbytes_split(tensor, nrows_split);
// pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
if (ne0 % MATRIX_ROW_PADDING != 0) {
total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
}
return total_size;
}
static bool ggml_backend_cuda_split_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
return ggml_backend_is_cuda(backend);
UNUSED(buft);
}
static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
return false;
UNUSED(buft);
}
static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
/* .get_name = */ ggml_backend_cuda_split_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment,
/* .get_alloc_size = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
/* .supports_backend = */ ggml_backend_cuda_split_buffer_type_supports_backend,
/* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
};
ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
// FIXME: this is not thread safe
static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
if (all_zero) {
tensor_split_arr = g_default_tensor_split;
} else {
float split_sum = 0.0f;
for (int i = 0; i < g_device_count; ++i) {
tensor_split_arr[i] = split_sum;
split_sum += tensor_split[i];
}
for (int i = 0; i < g_device_count; ++i) {
tensor_split_arr[i] /= split_sum;
}
}
auto it = buft_map.find(tensor_split_arr);
if (it != buft_map.end()) {
return &it->second;
}
struct ggml_backend_buffer_type buft {
/* .iface = */ ggml_backend_cuda_split_buffer_type_interface,
/* .context = */ new ggml_backend_cuda_split_buffer_type_context{tensor_split_arr},
};
auto result = buft_map.emplace(tensor_split_arr, buft);
return &result.first->second;
}
// host buffer type
static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
return "CUDA_Host";
return GGML_CUDA_NAME "_Host";
UNUSED(buft);
}
static const char * ggml_backend_cuda_host_buffer_name(ggml_backend_buffer_t buffer) {
return "CUDA_Host";
return GGML_CUDA_NAME "_Host";
UNUSED(buffer);
}
@ -9713,14 +9960,14 @@ static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
continue;
#ifndef NDEBUG
assert(node->backend == GGML_BACKEND_GPU);
assert(node->backend == GGML_BACKEND_GPU || node->backend == GGML_BACKEND_GPU_SPLIT);
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
assert(node->extra != nullptr);
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->backend == GGML_BACKEND_GPU);
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
assert(node->src[j]->backend == GGML_BACKEND_GPU || node->src[j]->backend == GGML_BACKEND_GPU_SPLIT);
//assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
assert(node->src[j]->extra != nullptr);
}
}

View file

@ -38,6 +38,8 @@ GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// split tensor buffer that splits matrices by rows across multiple devices
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);

View file

@ -1238,16 +1238,17 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
GGML_UNUSED(gpu);
}
static ggml_backend_buffer_type_t llama_default_buffer_type_split(int main_gpu, const float * tensor_split) {
static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_gpu, const float * tensor_split) {
ggml_backend_buffer_type_t buft = nullptr;
#ifdef GGML_USE_CUBLAS
// TODO
// buft = ggml_backend_cuda_buffer_type_split(tensor_split);
if (ggml_backend_cuda_get_device_count() > 1) {
buft = ggml_backend_cuda_split_buffer_type(tensor_split);
}
#endif
if (buft == nullptr) {
buft = llama_default_buffer_type_offload(main_gpu);
buft = llama_default_buffer_type_offload(fallback_gpu);
}
return buft;
@ -2357,13 +2358,6 @@ struct llama_model_loader {
throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
}
// FIXME: this should be ggml_backend_cuda_split_buffer_type
//if (backend == GGML_BACKEND_GPU_SPLIT) {
// if (ne.size() == 1) {
// throw std::runtime_error(format("%s: 1-dimensional tensor '%s' cannot be split on the GPU", __func__, name.c_str()));
// }
//}
{
bool is_ok = true;
for (size_t i = 0; i < ne.size(); ++i) {
@ -3148,12 +3142,12 @@ static bool llm_load_tensors(
// TODO: user configurable
enum gpu_split_mode {
CUDA_SPLIT_NONE, // single GPU
CUDA_SPLIT_LAYER, // offload layers to different GPUs
CUDA_SPLIT_ROW // split matrix rows across GPUs
LLAMA_SPLIT_NONE, // single GPU
LLAMA_SPLIT_LAYER, // offload layers to different GPUs
LLAMA_SPLIT_ROW // split matrix rows across GPUs
};
gpu_split_mode split_mode = CUDA_SPLIT_LAYER;
gpu_split_mode split_mode = LLAMA_SPLIT_LAYER;
const int64_t n_layer = hparams.n_layer;
const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
@ -3167,7 +3161,7 @@ static bool llm_load_tensors(
}
#ifdef GGML_USE_CUBLAS
if (split_mode == CUDA_SPLIT_LAYER) {
if (split_mode == LLAMA_SPLIT_LAYER) {
// calculate the split points
int device_count = ggml_backend_cuda_get_device_count();
float splits[GGML_CUDA_MAX_DEVICES];
@ -9181,7 +9175,8 @@ struct llama_context * llama_new_context_with_model(
std::vector<ggml_backend_t> backends;
// initialize backends
// TODO: only initialize the backends that are actually used
// FIXME: only initialize the backends that are actually used
// this is important for CUDA split buffers, only the main_gpu backend should be initialized
#ifdef GGML_USE_METAL
if (model->n_gpu_layers > 0) {
ctx->backend = ggml_backend_metal_init();