llama : initial ggml-backend integration (#4520)
* llama : initial ggml-backend integration
* add ggml-metal
* cuda backend can be used through ggml-backend with LLAMA_GGML_BACKEND_CUDA_TEST; access all tensor data with ggml_backend_tensor_get/set
* add ggml_backend_buffer_clear; zero-init KV cache buffer
* add ggml_backend_buffer_is_host, used to avoid copies if possible when accessing tensor data
* disable gpu backends with ngl 0
* more accurate mlock
* unmap offloaded part of the model
* use posix_fadvise64(.., POSIX_FADV_SEQUENTIAL) to improve performance with mmap (see the sketch below)
* update quantize and lora
* update session copy/set to use ggml-backend

ggml-ci

* use posix_fadvise instead of posix_fadvise64
* ggml_backend_alloc_ctx_tensors_from_buft : remove old print
* llama_mmap::align_offset : use pointers instead of references for out parameters
* restore progress_callback behavior
* move final progress_callback call to load_all_data
* cuda : fix fprintf format string (minor)
* do not offload scales
* llama_mmap : avoid unmapping the same fragments again in the destructor
* remove unnecessary unmap
* metal : add default log function that prints to stderr, cleanup code

ggml-ci

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
commit d232aca5a7
parent 31f27758fa

11 changed files with 926 additions and 752 deletions
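For context on the mmap-related bullets in the commit message, the following is a minimal, hypothetical sketch (not the actual llama_mmap code) of advising the kernel that a mapped model file will be read sequentially; the helper name and error handling are illustrative only.

    #include <fcntl.h>
    #include <sys/mman.h>
    #include <sys/stat.h>
    #include <unistd.h>

    // illustrative helper: map a model file read-only and hint sequential access
    void * map_model_sequential(const char * path, size_t * out_size) {
        int fd = open(path, O_RDONLY);
        if (fd < 0) { return NULL; }

        struct stat st;
        if (fstat(fd, &st) != 0) { close(fd); return NULL; }
        *out_size = (size_t) st.st_size;

        // hint: the file will be read once, front to back, so the kernel can
        // prefetch aggressively and drop pages behind the reader
        posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);

        void * addr = mmap(NULL, *out_size, PROT_READ, MAP_SHARED, fd, 0);
        close(fd); // the mapping stays valid after close
        return addr == MAP_FAILED ? NULL : addr;
    }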
ggml-cuda.cu (89 changes)
@@ -9081,7 +9081,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         char * buf;
         CUDA_CHECK(cudaMalloc(&buf, size));
-        char * buf_host = (char*)data + offset_split;
+        char * buf_host = (char *)data + offset_split;

         // set padding to 0 to avoid possible NaN values
         if (size > original_size) {
@@ -9226,11 +9226,10 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
     ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();

-    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW;
+    const bool inplace = tensor->view_src != nullptr;

-    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
-        ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+    if (inplace && (tensor->view_src->backend == GGML_BACKEND_GPU || tensor->view_src->backend == GGML_BACKEND_GPU_SPLIT)) {
+        ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->view_src->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t view_offset = 0;
         if (tensor->op == GGML_OP_VIEW) {
@@ -9317,7 +9316,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
     if (tensor->op == GGML_OP_MUL_MAT) {
         if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = " PRId64 ", src1->ne[3] = " PRId64 " - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
 #endif
             return false;
         }
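The hunk above is the "cuda : fix fprintf format string (minor)" item from the commit message: PRId64 expands only to the conversion suffix (for example "ld" or "lld"), so the '%' has to be written explicitly. A standalone illustration, not llama.cpp code:

    #include <inttypes.h>
    #include <stdio.h>

    int main(void) {
        int64_t ne3 = 12;
        printf("ne[3] = %" PRId64 "\n", ne3);    // correct: "%" PRId64 concatenates to "%ld" (or "%lld")
        // printf("ne[3] = " PRId64 "\n", ne3);  // the old bug: no '%', so the literal "ld" is printed and ne3 is ignored
        return 0;
    }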
@@ -9523,7 +9522,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
     ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;

     if (tensor->view_src != NULL && tensor->view_offs == 0) {
-        assert(tensor->view_src->buffer->buft == buffer->buft); // TODO
+        assert(tensor->view_src->buffer->buft == buffer->buft);
         tensor->backend = tensor->view_src->backend;
         tensor->extra = tensor->view_src->extra;
         return;
@@ -9554,23 +9553,34 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
 }

 static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);

-    CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
+    ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;

-    UNUSED(buffer);
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
 }

 static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);

-    CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
+    ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;

-    UNUSED(buffer);
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
 }

+static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
+}
+
 static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
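The three functions above back the backend-agnostic calls named in the commit message: ggml_backend_tensor_set/get and the new ggml_backend_buffer_clear used to zero-init the KV cache buffer. A hedged usage sketch follows, assuming the ggml-backend API introduced by this commit; the helper name and data are illustrative, not llama.cpp code:

    #include "ggml.h"
    #include "ggml-backend.h"

    // illustrative helper: zero a backend buffer and upload host data into one of its tensors
    static void init_and_upload(ggml_backend_buffer_t kv_buf, struct ggml_tensor * w, const void * host_w) {
        // zero-initialize an entire backend buffer (e.g. the KV cache buffer);
        // on CUDA this reaches ggml_backend_cuda_buffer_clear -> cudaMemset
        ggml_backend_buffer_clear(kv_buf, 0);

        // copy host data into a tensor regardless of which backend owns it;
        // on CUDA this reaches ggml_backend_cuda_buffer_set_tensor -> cudaMemcpy
        ggml_backend_tensor_set(w, host_w, 0, ggml_nbytes(w));
    }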
@@ -9581,6 +9591,7 @@ static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
     /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor_from = */ NULL,
     /* .cpy_tensor_to = */ NULL,
+    /* .clear = */ ggml_backend_cuda_buffer_clear,
 };

 // cuda buffer type
@@ -9632,35 +9643,36 @@ static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_t
     UNUSED(buft);
 }

-static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = {
+static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
     /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
     /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
     /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
     /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
+    /* .is_host = */ nullptr,
 };

 ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
-    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES];
-    static bool ggml_backend_buffer_type_cuda_initialized = false;
-    if (!ggml_backend_buffer_type_cuda_initialized) {
+    static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
+
+    static bool ggml_backend_cuda_buffer_type_initialized = false;
+
+    if (!ggml_backend_cuda_buffer_type_initialized) {
         for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
-            ggml_backend_buffer_type_cuda[i] = {
-                /* .iface = */ cuda_backend_buffer_type_interface,
+            ggml_backend_cuda_buffer_types[i] = {
+                /* .iface = */ ggml_backend_cuda_buffer_type_interface,
                 /* .context = */ (ggml_backend_buffer_type_context_t) (intptr_t) i,
             };
         }
-        ggml_backend_buffer_type_cuda_initialized = true;
+        ggml_backend_cuda_buffer_type_initialized = true;
     }

-    return &ggml_backend_buffer_type_cuda[device];
+    return &ggml_backend_cuda_buffer_types[device];
 }

 // host buffer type

 static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
-    CUDA_CHECK(cudaFreeHost(ctx->dev_ptr));
-    delete ctx;
+    CUDA_CHECK(cudaFreeHost(buffer->context));
 }

 static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
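ggml_backend_cuda_buffer_type(device), returned above, is the kind of buffer type that ggml_backend_alloc_ctx_tensors_from_buft (also touched by this commit) accepts to place a context's tensors in device memory. A rough sketch under that assumption; the header placement and tensor shapes are assumptions, not verbatim llama.cpp code:

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cuda.h"

    int main(void) {
        struct ggml_init_params params = {
            /* .mem_size   = */ ggml_tensor_overhead() * 16,
            /* .mem_buffer = */ NULL,
            /* .no_alloc   = */ true,   // metadata only; tensor data lives in the backend buffer
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096);
        (void) w;

        // allocate every tensor of the context in a single buffer of the CUDA buffer type
        ggml_backend_buffer_t buf =
            ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cuda_buffer_type(0));

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
        return 0;
    }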
@@ -9673,24 +9685,21 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
     buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;

     return buffer;

     UNUSED(buft);
 }

-struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = {
-    /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
-    /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
-    /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-    /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
-};
-
 ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
-    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = {
-        /* .iface = */ cuda_backend_host_buffer_type_interface,
+    static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
+        /* .iface = */ {
+            /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+            /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+            /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+            /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
+            /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+        },
         /* .context = */ nullptr,
     };

-    return &ggml_backend_buffer_type_cuda_host;
+    return &ggml_backend_cuda_buffer_type_host;
 }

 // backend
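The .is_host field wired up above is what ggml_backend_buffer_is_host (named in the commit message) reports; the point is to skip staging copies when a tensor is already host-visible, for example in the pinned-host buffer type defined here. A hedged sketch of that pattern, with an illustrative helper that is not part of llama.cpp:

    #include "ggml.h"
    #include "ggml-backend.h"

    // illustrative helper: read tensor data, copying only when the buffer is not host-visible
    static const void * read_tensor_data(const struct ggml_tensor * t, void * scratch, size_t size) {
        if (ggml_backend_buffer_is_host(t->buffer)) {
            return t->data;                               // direct pointer, no copy needed
        }
        ggml_backend_tensor_get(t, scratch, 0, size);     // e.g. on CUDA: cudaMemcpy device -> host
        return scratch;
    }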
@@ -9722,8 +9731,6 @@ static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tens
     ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;

     GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);

     CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
@@ -9733,8 +9740,6 @@ static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggm
     ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;

     GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);

     CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
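The async variants above enqueue cudaMemcpyAsync on the device's stream. As a general CUDA point (not specific to this diff), such copies only run asynchronously and overlap with compute when the host pointer is pinned, which is what the cudaMallocHost-based host buffer type earlier in this file provides. A plain CUDA sketch:

    #include <cuda_runtime.h>

    int main(void) {
        const size_t n = 1 << 20;
        float *h_pinned, *d_buf;
        cudaStream_t stream;

        cudaStreamCreate(&stream);
        cudaMallocHost(&h_pinned, n * sizeof(float)); // pinned host memory -> true async copies
        cudaMalloc(&d_buf, n * sizeof(float));

        // enqueue the copy; it can overlap with work on other streams
        cudaMemcpyAsync(d_buf, h_pinned, n * sizeof(float), cudaMemcpyHostToDevice, stream);
        cudaStreamSynchronize(stream); // wait before reusing h_pinned

        cudaFree(d_buf);
        cudaFreeHost(h_pinned);
        cudaStreamDestroy(stream);
        return 0;
    }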