cuda uma test
commit 518b75260b
parent cd93a28cb1
3 changed files with 68 additions and 7 deletions
ggml-cuda.cu (53 changed lines)
@@ -407,6 +407,7 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
 struct ggml_backend_cuda_buffer_context {
     int device;
+    void * host_ptr = nullptr;
     void * dev_ptr = nullptr;
     std::string name;
@@ -436,7 +437,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t
 
 GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-    return ctx->dev_ptr;
+    return ctx->host_ptr ? ctx->host_ptr : ctx->dev_ptr;
 }
 
 GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
@@ -447,7 +448,12 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
         return;
     }
 
-    if (ggml_is_quantized(tensor->type)) {
+    if (ctx->host_ptr) {
+        size_t offset = (size_t)((uint8_t*)tensor->data - (uint8_t*)ctx->host_ptr);
+        tensor->data = (uint8_t*)ctx->dev_ptr + offset;
+    }
+
+    if (ggml_is_quantized(tensor->type) && !ctx->host_ptr) {
         // initialize padding to 0 to avoid possible NaN values
         size_t original_size = ggml_nbytes(tensor);
         size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
@@ -560,11 +566,11 @@ GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backen
     size_t size = ggml_nbytes(tensor);
     int64_t ne0 = tensor->ne[0];
 
-    if (ggml_is_quantized(tensor->type)) {
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-    }
+    //if (ggml_is_quantized(tensor->type)) {
+    //    if (ne0 % MATRIX_ROW_PADDING != 0) {
+    //        size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+    //    }
+    //}
 
     return size;
 
@@ -3082,3 +3088,36 @@ GGML_CALL int ggml_backend_cuda_reg_devices() {
     }
     return device_count;
 }
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_cuda_buffer_from_ptr(int device, void * ptr, size_t size) {
+    ggml_backend_buffer_type_t buft = ggml_backend_cuda_buffer_type(device);
+    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+    ggml_cuda_set_device(buft_ctx->device);
+
+    //const size_t page_size = 4096;
+    //ptr = (void *)((uintptr_t)ptr & ~(page_size - 1));
+
+    cudaError_t err = cudaHostRegister(ptr, size, cudaHostRegisterMapped | cudaHostRegisterReadOnly);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+        GGML_CUDA_LOG_ERROR("%s: registering %.2f MiB on device %d: cudaHostRegister failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    void * dev_ptr;
+    err = cudaHostGetDevicePointer(&dev_ptr, ptr, 0);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+        GGML_CUDA_LOG_ERROR("%s: failed to get device pointer: %s\n", __func__, cudaGetErrorString(err));
+        cudaHostUnregister(ptr);
+        return nullptr;
+    }
+
+    ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
+    ctx->host_ptr = ptr;
+
+    return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
+}
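For context on the hunk above: ggml_backend_cuda_buffer_from_ptr builds on CUDA's host-memory registration API. The snippet below is a minimal standalone sketch of that pattern (register an existing host range as mapped pinned memory, then query its device-visible alias). It is illustrative only and not part of this commit; the allocation size and variable names are made up.

// Standalone sketch of the zero-copy pattern used by ggml_backend_cuda_buffer_from_ptr.
// Illustrative only: the allocation below stands in for an mmap-ed model file.
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

int main() {
    const size_t size = 1 << 20;                    // 1 MiB, arbitrary
    void * host_buf = aligned_alloc(4096, size);    // page-aligned host range
    if (host_buf == NULL) {
        return 1;
    }

    // register the range so the GPU can access it directly (UMA / zero-copy)
    cudaError_t err = cudaHostRegister(host_buf, size, cudaHostRegisterMapped);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaHostRegister failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    // obtain the device pointer that aliases the registered host range
    void * dev_ptr = NULL;
    err = cudaHostGetDevicePointer(&dev_ptr, host_buf, 0);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaHostGetDevicePointer failed: %s\n", cudaGetErrorString(err));
    } else {
        printf("host %p is visible to the device as %p\n", host_buf, dev_ptr);
    }

    cudaHostUnregister(host_buf);
    free(host_buf);
    return 0;
}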
ggml-cuda.h (2 changed lines)

@@ -31,6 +31,8 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_typ
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
+GGML_CALL ggml_backend_buffer_t ggml_backend_cuda_buffer_from_ptr(int device, void * ptr, size_t size);
+
 GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
 GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
 GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
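A hypothetical caller of the declaration above, separate from the llama.cpp integration shown below. The sketch assumes a POSIX host, an existing file named "model.gguf", device index 0, and skips cleanup of the returned buffer for brevity; none of it is part of this commit.

// Hypothetical caller of ggml_backend_cuda_buffer_from_ptr; not part of this commit.
// Assumes a POSIX host and an existing file "model.gguf".
#include <fcntl.h>
#include <stdio.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include "ggml-cuda.h"

int main() {
    int fd = open("model.gguf", O_RDONLY);
    if (fd < 0) { perror("open"); return 1; }

    struct stat st;
    if (fstat(fd, &st) != 0) { perror("fstat"); return 1; }

    // map the whole file read-only; pages stay backed by the file (no copy)
    void * addr = mmap(NULL, st.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
    if (addr == MAP_FAILED) { perror("mmap"); return 1; }

    // expose the mapping to CUDA device 0 as a backend buffer (zero-copy)
    ggml_backend_buffer_t buf = ggml_backend_cuda_buffer_from_ptr(0, addr, (size_t) st.st_size);
    if (buf == nullptr) {
        fprintf(stderr, "ggml_backend_cuda_buffer_from_ptr failed\n");
    }
    // ... allocate tensors from `buf` here ...

    munmap(addr, st.st_size);
    close(fd);
    return 0;
}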
llama.cpp (20 changed lines)
@@ -6108,6 +6108,26 @@ static bool llm_load_tensors(
             }
         }
 #endif
+
+#ifdef GGML_USE_CUDA
+        else if (ml.use_mmap && use_mmap_buffer && buft == ggml_backend_cuda_buffer_type(0)) {
+            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                void * addr = nullptr;
+                size_t first, last;
+                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
+                if (first >= last) {
+                    continue;
+                }
+                ggml_backend_buffer_t buf = ggml_backend_cuda_buffer_from_ptr(0, (char *) addr + first, last - first);
+                if (buf == nullptr) {
+                    throw std::runtime_error("unable to allocate backend CUDA buffer");
+                }
+                model.bufs.push_back(buf);
+                bufs.emplace(idx, buf);
+            }
+        }
+#endif
+
         else {
             ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
             if (buf == nullptr) {