diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index e9500f3a1..91b432da9 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -333,8 +333,12 @@ struct ggml_backend_sycl_context { // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -345,6 +349,17 @@ struct ggml_backend_sycl_context { ggml_sycl_pool & pool() { return pool(device); } + + ggml_sycl_pool & host_pool(int device) { + if (host_pools[device] == nullptr) { + host_pools[device] = new_pool_for_host(stream(device,0), device); + } + return *host_pools[device]; + } + + ggml_sycl_pool & host_pool() { + return host_pool(device); + } }; // common device functions diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index e167948e7..645f681d8 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -18,9 +18,16 @@ #include #include #include +#include +#include "ggml-sycl.h" #include "ggml.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" + #if defined(__linux__) #include #elif defined(_WIN64) @@ -82,6 +89,16 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { return device_type.str(); } +template +struct matrix_info_t +{ + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; +}; + namespace dpct { typedef sycl::queue *queue_ptr; @@ -1731,22 +1748,12 @@ namespace dpct oneapi::mkl::transpose b_trans, int m, int n, int k, const void *alpha, const void **a, int lda, const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) + int ldc, int batch_size, matrix_info_t* matrix_info) { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[1] = b_trans; matrix_info->value_info[0] = alpha_value; @@ -1763,23 +1770,19 @@ namespace dpct sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), + matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, + reinterpret_cast(matrix_info->value_info+1), reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #else sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, - matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast(c), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #endif - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); } template @@ -2428,19 +2431,9 @@ namespace dpct library_data_t a_type, int lda, const void *b[], library_data_t b_type, int ldb, const void *beta, void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) + int batch_size, library_data_t scaling_type, + matrix_info_t* matrix_info) { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - std::uint64_t key = detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) @@ -2451,7 +2444,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2460,27 +2453,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2490,7 +2463,7 @@ namespace dpct detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } #ifdef __INTEL_MKL__ @@ -2501,7 +2474,7 @@ namespace dpct detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2510,7 +2483,7 @@ namespace dpct { detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); + b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #endif @@ -2525,7 +2498,7 @@ namespace dpct detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2534,7 +2507,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2543,7 +2516,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2558,7 +2531,7 @@ namespace dpct sycl::half beta_half(beta_value); detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); + batch_size, matrix_info); break; } default: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 037c8093e..d89286d89 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -37,6 +37,7 @@ #include "ggml-backend-impl.h" #include "ggml-sycl/backend.hpp" +#include "ggml-sycl/dpct/helper.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" @@ -1173,6 +1174,92 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +struct ggml_sycl_pool_host : public ggml_sycl_pool { + + int device; + queue_ptr qptr; + + inline static int counter{0}; + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + // Set arbitrarly to 64 + static constexpr int MAX_POOL_SIZE{64}; + std::vector buffer_pool = std::vector(MAX_POOL_SIZE); + size_t pool_size = 0; + + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : + qptr(qptr_), + device(device_) { + } + + ~ggml_sycl_pool_host() { + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + b.ptr = nullptr; + pool_size -= b.size; + b.size = 0; + } + } + counter = 0; + } + + void * alloc(size_t size, size_t * actual_size) override { + if ( counter == MAX_POOL_SIZE){ + ggml_sycl_buffer b = buffer_pool[0]; + size_t look_ahead_size = (size_t) (1.05 * size); + void *ptr = b.ptr; + *actual_size = b.size; + counter = 1; + return ptr; + } + ggml_sycl_buffer& b = buffer_pool[counter]; + + if (b.ptr == nullptr) { + void * ptr; + + SYCL_CHECK( + CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_host( + size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } + else if (b.ptr != nullptr) { + ++counter; + b.size = size; + return b.ptr; + } + } + + void free(void * ptr, size_t size) override { + // if the pool is not completed add the pointer to it in place of the first nullptr found. + // Otherwise do nothing, pointers will be freed once the pool is deallocated. + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { + // return pool for the host to speed up memory management + return std::unique_ptr(new ggml_sycl_pool_host(qptr, device)); +} + std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { // TBD: NO VMM support // if (ggml_sycl_info().devices[device].vmm) { @@ -3363,6 +3450,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(),1); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -3398,7 +3486,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, (const void **)(ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type))); + cu_compute_type, (matrix_info_t*)matrix_info.get()))); } } catch (sycl::exception const &exc) {