From 86869fbdab7f81c0569ae02788a587e17f167bde Mon Sep 17 00:00:00 2001
From: Markus Tavenrath
Date: Thu, 13 Jun 2024 14:32:03 +0200
Subject: [PATCH] Change assertions to exceptions in llama_file, find the
 correct CUDA backend to create CUDA resources, and respect the use_mmap flag
 again for CUDA.

---
 llama.cpp | 117 +++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 77 insertions(+), 40 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index ac4582864..a6eb79c99 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1288,7 +1288,6 @@ private:
         DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                                       NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
         if (!bufLen) {
-            std::stringstream ss;
             ret = format("Win32 error code: %s", error_code);
         } else {
             ret = lpMsgBuf;
@@ -1316,8 +1315,9 @@ public:
         LARGE_INTEGER li;
         li.QuadPart = 0;
         BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
-
-        GGML_ASSERT(ret);
+        if (!ret) {
+            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError())));
+        }
         return li.QuadPart;
     }

@@ -1332,11 +1332,15 @@ public:
         LARGE_INTEGER li;
         li.QuadPart = offset;
         BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
-
-        GGML_ASSERT(ret);
+        if (!ret) {
+            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError())));
+        }
     }

     void read_raw(void * ptr, size_t len) const {
+        // On Win32, ReadFile is significantly faster than fread, which in turn is significantly faster than std::fstream. Thus
+        // use the Win32 API to do file I/O instead of the C/C++ library functions.
+
         // There are conditions under which ReadFile cannot read chunks >64MB.
         // Thus split the operation into smaller chunks if len exceeds this limit.
         size_t bytes_read = 0;
@@ -1410,7 +1414,10 @@ public:
 #else
         long ret = std::ftell(fp);
 #endif
-        GGML_ASSERT(ret != -1); // this really shouldn't fail
+        if (ret == -1) {
+            throw std::runtime_error(format("ftell error: %s", strerror(errno)));
+        }
+
         return (size_t) ret;
     }

@@ -1420,7 +1427,9 @@ public:
 #else
         int ret = std::fseek(fp, (long) offset, whence);
 #endif
-        GGML_ASSERT(ret == 0); // same
+        if (ret != 0) {
+            throw std::runtime_error(format("seek error: %s", strerror(errno)));
+        }
     }

     void read_raw(void * ptr, size_t len) const {
@@ -3831,19 +3840,40 @@ struct llama_model_loader {
         std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;

 #if defined(GGML_USE_CUDA)
+        // 4 staging buffers for async uploads, each sized 1MB, seem to be a good default for single NVMe drives.
+        // NVMe RAID configurations might require more / larger buffers.
+        constexpr size_t num_buffers = 4;
+        constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
+
         std::vector<ggml_backend_buffer_t> host_buffers;
         std::vector<void*> host_ptrs;
         std::vector<ggml_backend_event_t> events;
         size_t buffer_idx = 0; // buffer to use for async loads

-        ggml_backend_t backend = ggml_backend_cuda_init(0); // TODO how to get the CUDA device / backend here?
+        ggml_backend_t cuda_backend = nullptr;
+        if (!use_mmap) {
+            // When not using mmapped I/O, use async uploads from pinned memory to GPU memory.
+            // First determine if the CUDA backend is active, and if so, determine the device ID.
+            ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
+            if (buf) {
+                ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf);
+                for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+                    auto cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
+                    if (buffer_type == cuda_buffer_type) {
+                        cuda_backend = ggml_backend_cuda_init(i);
+                        break;
+                    }
+                }
+            }

-        constexpr size_t num_buffers = 4;
-        constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
-        for (size_t idx = 0; idx < num_buffers; ++idx) {
-            host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
-            host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
-            events.emplace_back(ggml_backend_event_new(backend));
+
+            // If the CUDA backend is active, create pinned memory buffers and events for synchronisation.
+            if (cuda_backend) {
+                for (size_t idx = 0; idx < num_buffers; ++idx) {
+                    host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
+                    host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
+                    events.emplace_back(ggml_backend_event_new(cuda_backend));
+                }
+            }
         }
 #endif

@@ -3903,32 +3933,37 @@ struct llama_model_loader {
                 }
             } else {
 #if defined(GGML_USE_CUDA)
-                file->seek(weight->offs, SEEK_SET);
+                // If cuda_backend is valid, load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
+                if (cuda_backend) {
+                    file->seek(weight->offs, SEEK_SET);

-                size_t bytes_read = 0;
+                    size_t bytes_read = 0;

-                while (bytes_read < n_size)
-                {
-                    size_t read_iteration = std::min(buffer_size, n_size - bytes_read);
+                    while (bytes_read < n_size)
+                    {
+                        size_t read_iteration = std::min(buffer_size, n_size - bytes_read);

-                    ggml_backend_event_synchronize(events[buffer_idx]);
-                    file->read_raw(host_ptrs[buffer_idx], read_iteration);
-                    ggml_backend_tensor_set_async(backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
-                    ggml_backend_event_record(events[buffer_idx]);
+                        ggml_backend_event_synchronize(events[buffer_idx]);
+                        file->read_raw(host_ptrs[buffer_idx], read_iteration);
+                        ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
+                        ggml_backend_event_record(events[buffer_idx]);

-                    bytes_read += read_iteration;
-                    ++buffer_idx;
-                    buffer_idx %= num_buffers;
-                }
-#else
-                read_buf.resize(n_size);
-                file->seek(weight->offs, SEEK_SET);
-                file->read_raw(read_buf.data(), n_size);
-                ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
-                if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
-                    throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                        bytes_read += read_iteration;
+                        ++buffer_idx;
+                        buffer_idx %= num_buffers;
+                    }
                 }
+                else
 #endif
+                {
+                    read_buf.resize(n_size);
+                    file->seek(weight->offs, SEEK_SET);
+                    file->read_raw(read_buf.data(), n_size);
+                    ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
+                    if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
+                        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                    }
+                }
             }
         }

@@ -3936,12 +3971,14 @@ struct llama_model_loader {
         }

 #if defined(GGML_USE_CUDA)
-        for (size_t idx = 0; idx < num_buffers;++idx) {
-            ggml_backend_event_synchronize(events[idx]);
-            ggml_backend_event_free(events[idx]);
-            ggml_backend_buffer_free(host_buffers[idx]);
-
-            //ggml_backend_free(backend);
+        // free temporary resources used for async CUDA uploads
+        if (cuda_backend) {
+            for (size_t idx = 0; idx < num_buffers; ++idx) {
+                ggml_backend_event_synchronize(events[idx]);
+                ggml_backend_event_free(events[idx]);
+                ggml_backend_buffer_free(host_buffers[idx]);
+            }
+            ggml_backend_free(cuda_backend);
         }
 #endif
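
For readers outside the llama.cpp tree, the chunked-upload loop in the patch boils down to a round-robin staging-buffer pipeline: a staging buffer may only be overwritten once the upload that last used it has been waited on, which is what the ggml_backend_event_synchronize / ggml_backend_event_record pair enforces. The sketch below is a minimal, self-contained illustration of that pattern, not part of the patch: std::future stands in for ggml_backend_event_t, an async memcpy stands in for ggml_backend_tensor_set_async and the GPU copy, and the file read is simulated with memcpy as well, so all names in it are placeholders.

    // staging_upload_sketch.cpp -- illustrative only, not part of llama.cpp
    #include <algorithm>
    #include <cstdio>
    #include <cstring>
    #include <future>
    #include <vector>

    int main() {
        constexpr size_t num_buffers = 4;                 // staging buffers reused round-robin
        constexpr size_t buffer_size = 1 * 1024 * 1024;   // 1MB per chunk, as in the patch

        std::vector<std::vector<char>> host_buffers(num_buffers, std::vector<char>(buffer_size));
        std::vector<std::future<void>> events(num_buffers);    // stand-in for ggml_backend_event_t
        std::vector<char> src(10 * buffer_size + 1234, 'x');   // stand-in for the tensor data in the file
        std::vector<char> dst(src.size());                     // stand-in for the GPU-side tensor

        size_t bytes_read = 0;
        size_t buffer_idx = 0;
        while (bytes_read < src.size()) {
            const size_t chunk = std::min(buffer_size, src.size() - bytes_read);

            // wait until the previous upload from this staging buffer has completed
            if (events[buffer_idx].valid()) {
                events[buffer_idx].get();
            }

            // "read" the next chunk into the staging buffer (file->read_raw in the patch)
            std::memcpy(host_buffers[buffer_idx].data(), src.data() + bytes_read, chunk);

            // start the asynchronous upload (ggml_backend_tensor_set_async in the patch)
            char * in  = host_buffers[buffer_idx].data();
            char * out = dst.data() + bytes_read;
            events[buffer_idx] = std::async(std::launch::async, [in, out, chunk] {
                std::memcpy(out, in, chunk);
            });

            bytes_read += chunk;
            buffer_idx = (buffer_idx + 1) % num_buffers;
        }

        // drain all outstanding uploads before the staging buffers are destroyed
        for (auto & event : events) {
            if (event.valid()) {
                event.get();
            }
        }

        std::printf("copied %zu bytes, match: %d\n", bytes_read,
                    std::memcmp(src.data(), dst.data(), src.size()) == 0);
        return 0;
    }

The final drain loop mirrors the cleanup hunk above: every event is synchronized before the staging buffers and the backend are freed.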