cuda : rename build flag to LLAMA_CUDA (#6299)
parent b06c16ef9f
commit 280345968d

28 changed files with 129 additions and 115 deletions

llama.cpp (26 changed lines shown below)
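The build-flag rename is mirrored in the source by renaming the compile-time macro GGML_USE_CUBLAS to GGML_USE_CUDA, which is what every hunk below does. As a minimal sketch (not part of this commit), code outside this repository that still defines the old macro could bridge the transition with a shim like the one below; the mapping itself is an assumption for illustration, only the macro names and the include come from the diff.

#if defined(GGML_USE_CUBLAS) && !defined(GGML_USE_CUDA)
// Hypothetical compatibility shim: treat the old cuBLAS-era macro as the new one.
#define GGML_USE_CUDA
#endif

#ifdef GGML_USE_CUDA
// Same include that the diff now guards with the renamed macro.
# include "ggml-cuda.h"
#endif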
@@ -7,7 +7,7 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
 # include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
 # include "ggml-opencl.h"
@@ -1505,7 +1505,7 @@ static std::string llama_token_to_piece(const struct llama_context * ctx, llama_
 static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
     ggml_backend_buffer_type_t buft = nullptr;
 
-#if defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_CUDA)
     // host buffers should only be used when data is expected to be copied to/from the GPU
     if (host_buffer) {
         buft = ggml_backend_cuda_host_buffer_type();
@@ -1535,7 +1535,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
 
 #ifdef GGML_USE_METAL
     buft = ggml_backend_metal_buffer_type();
-#elif defined(GGML_USE_CUBLAS)
+#elif defined(GGML_USE_CUDA)
     buft = ggml_backend_cuda_buffer_type(gpu);
 #elif defined(GGML_USE_VULKAN)
     buft = ggml_backend_vk_buffer_type(gpu);
@@ -1561,7 +1561,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
 static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_gpu, const float * tensor_split) {
     ggml_backend_buffer_type_t buft = nullptr;
 
-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
     if (ggml_backend_cuda_get_device_count() > 1) {
         buft = ggml_backend_cuda_split_buffer_type(tensor_split);
     }
@@ -1582,7 +1582,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_g
 }
 
 static size_t llama_get_device_count() {
-#if defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_CUDA)
     return ggml_backend_cuda_get_device_count();
 #elif defined(GGML_USE_SYCL)
     return ggml_backend_sycl_get_device_count();
@@ -1594,7 +1594,7 @@ static size_t llama_get_device_count() {
 }
 
 static size_t llama_get_device_memory(int device) {
-#if defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_CUDA)
     size_t total;
     size_t free;
     ggml_backend_cuda_get_device_memory(device, &total, &free);
@@ -2080,7 +2080,7 @@ struct llama_model {
             ggml_free(ctx);
         }
         for (ggml_backend_buffer_t buf : bufs) {
-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
             if (ggml_backend_buffer_get_type(buf) == ggml_backend_cpu_buffer_type()) {
                 ggml_backend_cuda_unregister_host_buffer(ggml_backend_buffer_get_base(buf));
             }
@@ -5269,7 +5269,7 @@ static bool llm_load_tensors(
             }
             model.bufs.push_back(buf);
             bufs.emplace(idx, buf);
-#ifdef GGML_USE_CUBLAS
+#ifdef GGML_USE_CUDA
             if (n_layer >= n_gpu_layers) {
                 ggml_backend_cuda_register_host_buffer(
                     ggml_backend_buffer_get_base(buf),
@@ -13371,7 +13371,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
 size_t llama_max_devices(void) {
 #if defined(GGML_USE_METAL)
     return 1;
-#elif defined(GGML_USE_CUBLAS)
+#elif defined(GGML_USE_CUDA)
     return GGML_CUDA_MAX_DEVICES;
 #elif defined(GGML_USE_SYCL)
     return GGML_SYCL_MAX_DEVICES;
@@ -13391,8 +13391,8 @@ bool llama_supports_mlock(void) {
 }
 
 bool llama_supports_gpu_offload(void) {
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
-    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
+#if defined(GGML_USE_CUDA) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
+    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
     // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
     return true;
 #else
@@ -13597,7 +13597,7 @@ struct llama_context * llama_new_context_with_model(
             }
             ctx->backends.push_back(ctx->backend_metal);
         }
-#elif defined(GGML_USE_CUBLAS)
+#elif defined(GGML_USE_CUDA)
         if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
             // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
             ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
@@ -13744,7 +13744,7 @@ struct llama_context * llama_new_context_with_model(
 
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
     bool pipeline_parallel = llama_get_device_count() > 1 && model->n_gpu_layers > (int)model->hparams.n_layer && model->split_mode == LLAMA_SPLIT_MODE_LAYER;
-#ifndef GGML_USE_CUBLAS
+#ifndef GGML_USE_CUDA
     // pipeline parallelism requires support for async compute and events
     // currently this is only implemented in the CUDA backend
     pipeline_parallel = false;
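For context, here is a tiny usage sketch (an illustration, not part of this commit; it assumes llama.cpp was built with the renamed CUDA flag or another GPU backend) exercising the two public entry points whose backend guards this diff touches, llama_supports_gpu_offload() and llama_max_devices():

#include <cstdio>
#include "llama.h"

int main() {
    // Reports whether this build was compiled with a GPU offload backend
    // (CUDA, Metal, Vulkan, SYCL, Kompute or CLBlast, per the guard above).
    std::printf("gpu offload supported: %s\n", llama_supports_gpu_offload() ? "yes" : "no");
    // Upper bound on usable devices; GGML_CUDA_MAX_DEVICES when built with GGML_USE_CUDA.
    std::printf("max devices: %zu\n", llama_max_devices());
    return 0;
}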