Implement host pool for matrix_info
Create a new memory pool on the host to provide the memory for the matrix_info structure needed to launch gemm_batch from oneMKL/oneMath. Remove complex support in gemm_batch since it is not used in llama.cpp. Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
parent
b4d92a59a2
commit
0afea98ef0
3 changed files with 137 additions and 61 deletions
|
@ -333,8 +333,12 @@ struct ggml_backend_sycl_context {
|
||||||
// pool
|
// pool
|
||||||
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
||||||
|
|
||||||
|
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
|
||||||
|
|
||||||
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
||||||
|
|
||||||
|
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
|
||||||
|
|
||||||
ggml_sycl_pool & pool(int device) {
|
ggml_sycl_pool & pool(int device) {
|
||||||
if (pools[device] == nullptr) {
|
if (pools[device] == nullptr) {
|
||||||
pools[device] = new_pool_for_device(stream(device,0), device);
|
pools[device] = new_pool_for_device(stream(device,0), device);
|
||||||
|
@ -345,6 +349,17 @@ struct ggml_backend_sycl_context {
|
||||||
ggml_sycl_pool & pool() {
|
ggml_sycl_pool & pool() {
|
||||||
return pool(device);
|
return pool(device);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the host-side pool registered for `device`. The pool is built on
// first use through new_pool_for_host(), using stream 0 of that device, and
// then kept in host_pools[device] for subsequent calls.
ggml_sycl_pool & host_pool(int device) {
    std::unique_ptr<ggml_sycl_pool> & slot = host_pools[device];
    if (!slot) {
        slot = new_pool_for_host(stream(device, 0), device);
    }
    return *slot;
}
|
||||||
|
|
||||||
|
// Convenience overload: forwards to host_pool(int) with this context's own
// `device` member.
ggml_sycl_pool & host_pool() {
    return this->host_pool(this->device);
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// common device functions
|
// common device functions
|
||||||
|
|
|
@ -18,9 +18,16 @@
|
||||||
#include <syclcompat/math.hpp>
|
#include <syclcompat/math.hpp>
|
||||||
#include <oneapi/mkl.hpp>
|
#include <oneapi/mkl.hpp>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "ggml-sycl.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
#include "ggml-backend-impl.h"
|
||||||
|
#include "ggml-alloc.h"
|
||||||
|
#include "ggml-impl.h"
|
||||||
|
|
||||||
#if defined(__linux__)
|
#if defined(__linux__)
|
||||||
#include <sys/mman.h>
|
#include <sys/mman.h>
|
||||||
#elif defined(_WIN64)
|
#elif defined(_WIN64)
|
||||||
|
@ -82,6 +89,16 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
||||||
return device_type.str();
|
return device_type.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename Ts>
|
||||||
|
struct matrix_info_t
|
||||||
|
{
|
||||||
|
oneapi::mkl::transpose transpose_info[2];
|
||||||
|
Ts value_info[2];
|
||||||
|
std::int64_t size_info[3];
|
||||||
|
std::int64_t ld_info[3];
|
||||||
|
std::int64_t groupsize_info;
|
||||||
|
};
|
||||||
|
|
||||||
namespace dpct
|
namespace dpct
|
||||||
{
|
{
|
||||||
typedef sycl::queue *queue_ptr;
|
typedef sycl::queue *queue_ptr;
|
||||||
|
@ -1731,22 +1748,12 @@ namespace dpct
|
||||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
||||||
const void *alpha, const void **a, int lda,
|
const void *alpha, const void **a, int lda,
|
||||||
const void **b, int ldb, const void *beta, void **c,
|
const void **b, int ldb, const void *beta, void **c,
|
||||||
int ldc, int batch_size)
|
int ldc, int batch_size, matrix_info_t<float>* matrix_info)
|
||||||
{
|
{
|
||||||
struct matrix_info_t
|
|
||||||
{
|
|
||||||
oneapi::mkl::transpose transpose_info[2];
|
|
||||||
Ts value_info[2];
|
|
||||||
std::int64_t size_info[3];
|
|
||||||
std::int64_t ld_info[3];
|
|
||||||
std::int64_t groupsize_info;
|
|
||||||
};
|
|
||||||
|
|
||||||
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
||||||
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
||||||
|
|
||||||
matrix_info_t *matrix_info =
|
|
||||||
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
|
|
||||||
matrix_info->transpose_info[0] = a_trans;
|
matrix_info->transpose_info[0] = a_trans;
|
||||||
matrix_info->transpose_info[1] = b_trans;
|
matrix_info->transpose_info[1] = b_trans;
|
||||||
matrix_info->value_info[0] = alpha_value;
|
matrix_info->value_info[0] = alpha_value;
|
||||||
|
@ -1763,23 +1770,19 @@ namespace dpct
|
||||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
||||||
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
||||||
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
|
matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info), reinterpret_cast<const Ta **>(a),
|
||||||
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
||||||
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
reinterpret_cast<Ts*>(matrix_info->value_info+1), reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
||||||
&(matrix_info->groupsize_info));
|
&(matrix_info->groupsize_info));
|
||||||
#else
|
#else
|
||||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
||||||
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info),
|
||||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
||||||
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
matrix_info->ld_info + 1, reinterpret_cast<Ts*>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
|
||||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
q.submit([&](sycl::handler &cgh)
|
|
||||||
{
|
|
||||||
cgh.depends_on(e);
|
|
||||||
cgh.host_task([=] { std::free(matrix_info); }); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta, class Tb, class Tc, class Ts>
|
template <class Ta, class Tb, class Tc, class Ts>
|
||||||
|
@ -2428,19 +2431,9 @@ namespace dpct
|
||||||
library_data_t a_type, int lda, const void *b[],
|
library_data_t a_type, int lda, const void *b[],
|
||||||
library_data_t b_type, int ldb, const void *beta,
|
library_data_t b_type, int ldb, const void *beta,
|
||||||
void *c[], library_data_t c_type, int ldc,
|
void *c[], library_data_t c_type, int ldc,
|
||||||
int batch_size, library_data_t scaling_type)
|
int batch_size, library_data_t scaling_type,
|
||||||
|
matrix_info_t<float>* matrix_info)
|
||||||
{
|
{
|
||||||
if (scaling_type == library_data_t::real_float &&
|
|
||||||
c_type == library_data_t::complex_float)
|
|
||||||
{
|
|
||||||
scaling_type = library_data_t::complex_float;
|
|
||||||
}
|
|
||||||
else if (scaling_type == library_data_t::real_double &&
|
|
||||||
c_type == library_data_t::complex_double)
|
|
||||||
{
|
|
||||||
scaling_type = library_data_t::complex_double;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::uint64_t key =
|
std::uint64_t key =
|
||||||
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
||||||
switch (key)
|
switch (key)
|
||||||
|
@ -2451,7 +2444,7 @@ namespace dpct
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<float, float, float, float>(
|
detail::gemm_batch_impl<float, float, float, float>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2460,27 +2453,7 @@ namespace dpct
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<double, double, double, double>(
|
detail::gemm_batch_impl<double, double, double, double>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
|
||||||
}
|
|
||||||
case detail::get_type_combination_id(
|
|
||||||
library_data_t::complex_float, library_data_t::complex_float,
|
|
||||||
library_data_t::complex_float, library_data_t::complex_float):
|
|
||||||
{
|
|
||||||
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
|
|
||||||
std::complex<float>, std::complex<float>>(
|
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
batch_size);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case detail::get_type_combination_id(
|
|
||||||
library_data_t::complex_double, library_data_t::complex_double,
|
|
||||||
library_data_t::complex_double, library_data_t::complex_double):
|
|
||||||
{
|
|
||||||
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
|
|
||||||
std::complex<double>, std::complex<double>>(
|
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
batch_size);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2490,7 +2463,7 @@ namespace dpct
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
|
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
|
||||||
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
|
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
|
||||||
a, lda, b, ldb, beta, c, ldc,
|
a, lda, b, ldb, beta, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#ifdef __INTEL_MKL__
|
#ifdef __INTEL_MKL__
|
||||||
|
@ -2501,7 +2474,7 @@ namespace dpct
|
||||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
||||||
oneapi::mkl::bfloat16, float>(
|
oneapi::mkl::bfloat16, float>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2510,7 +2483,7 @@ namespace dpct
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
|
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
|
||||||
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
|
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
|
||||||
b, ldb, beta, c, ldc, batch_size);
|
b, ldb, beta, c, ldc, batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -2525,7 +2498,7 @@ namespace dpct
|
||||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
|
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
|
||||||
float>(q, a_trans, b_trans, m, n, k, &alpha_float,
|
float>(q, a_trans, b_trans, m, n, k, &alpha_float,
|
||||||
a, lda, b, ldb, &beta_float, c, ldc,
|
a, lda, b, ldb, &beta_float, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2534,7 +2507,7 @@ namespace dpct
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2543,7 +2516,7 @@ namespace dpct
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2558,7 +2531,7 @@ namespace dpct
|
||||||
sycl::half beta_half(beta_value);
|
sycl::half beta_half(beta_value);
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||||
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
|
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
|
||||||
batch_size);
|
batch_size, matrix_info);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -37,6 +37,7 @@
|
||||||
#include "ggml-backend-impl.h"
|
#include "ggml-backend-impl.h"
|
||||||
|
|
||||||
#include "ggml-sycl/backend.hpp"
|
#include "ggml-sycl/backend.hpp"
|
||||||
|
#include "ggml-sycl/dpct/helper.hpp"
|
||||||
#include "ggml-sycl/presets.hpp"
|
#include "ggml-sycl/presets.hpp"
|
||||||
#include "ggml-sycl/gemm.hpp"
|
#include "ggml-sycl/gemm.hpp"
|
||||||
|
|
||||||
|
@ -1173,6 +1174,92 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ggml_sycl_pool_host : public ggml_sycl_pool {
|
||||||
|
|
||||||
|
int device;
|
||||||
|
queue_ptr qptr;
|
||||||
|
|
||||||
|
inline static int counter{0};
|
||||||
|
struct ggml_sycl_buffer {
|
||||||
|
void * ptr = nullptr;
|
||||||
|
size_t size = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set arbitrarly to 64
|
||||||
|
static constexpr int MAX_POOL_SIZE{64};
|
||||||
|
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
|
||||||
|
size_t pool_size = 0;
|
||||||
|
|
||||||
|
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) :
|
||||||
|
qptr(qptr_),
|
||||||
|
device(device_) {
|
||||||
|
}
|
||||||
|
|
||||||
|
~ggml_sycl_pool_host() {
|
||||||
|
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
|
||||||
|
ggml_sycl_buffer & b = buffer_pool[i];
|
||||||
|
if (b.ptr != nullptr) {
|
||||||
|
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
|
||||||
|
b.ptr = nullptr;
|
||||||
|
pool_size -= b.size;
|
||||||
|
b.size = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
counter = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void * alloc(size_t size, size_t * actual_size) override {
|
||||||
|
if ( counter == MAX_POOL_SIZE){
|
||||||
|
ggml_sycl_buffer b = buffer_pool[0];
|
||||||
|
size_t look_ahead_size = (size_t) (1.05 * size);
|
||||||
|
void *ptr = b.ptr;
|
||||||
|
*actual_size = b.size;
|
||||||
|
counter = 1;
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
ggml_sycl_buffer& b = buffer_pool[counter];
|
||||||
|
|
||||||
|
if (b.ptr == nullptr) {
|
||||||
|
void * ptr;
|
||||||
|
|
||||||
|
SYCL_CHECK(
|
||||||
|
CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_host(
|
||||||
|
size, *qptr)));
|
||||||
|
if (!ptr) {
|
||||||
|
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
pool_size += size;
|
||||||
|
*actual_size = size;
|
||||||
|
counter = counter + 1;
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
else if (b.ptr != nullptr) {
|
||||||
|
++counter;
|
||||||
|
b.size = size;
|
||||||
|
return b.ptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void free(void * ptr, size_t size) override {
|
||||||
|
// if the pool is not completed add the pointer to it in place of the first nullptr found.
|
||||||
|
// Otherwise do nothing, pointers will be freed once the pool is deallocated.
|
||||||
|
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
|
||||||
|
ggml_sycl_buffer& b = buffer_pool[i];
|
||||||
|
if (b.ptr == nullptr) {
|
||||||
|
b.ptr = ptr;
|
||||||
|
b.size = size;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Creates the host-memory pool for `device`, bound to the queue `qptr`.
// The returned pointer owns a ggml_sycl_pool_host exposed through the
// ggml_sycl_pool interface.
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
    // return pool for the host to speed up memory management
    return std::make_unique<ggml_sycl_pool_host>(qptr, device);
}
|
||||||
|
|
||||||
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
|
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
|
||||||
// TBD: NO VMM support
|
// TBD: NO VMM support
|
||||||
// if (ggml_sycl_info().devices[device].vmm) {
|
// if (ggml_sycl_info().devices[device].vmm) {
|
||||||
|
@ -3363,6 +3450,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||||
|
|
||||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
||||||
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
||||||
|
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(),1);
|
||||||
|
|
||||||
sycl::range<3> block_dims(1, ne12, ne13);
|
sycl::range<3> block_dims(1, ne12, ne13);
|
||||||
/*
|
/*
|
||||||
|
@ -3398,7 +3486,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||||
(const void **)(ptrs_src.get() + 1 * ne23),
|
(const void **)(ptrs_src.get() + 1 * ne23),
|
||||||
dpct::library_data_t::real_half, nb11 / nb10, beta,
|
dpct::library_data_t::real_half, nb11 / nb10, beta,
|
||||||
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
|
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
|
||||||
cu_compute_type)));
|
cu_compute_type, (matrix_info_t<float>*)matrix_info.get())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (sycl::exception const &exc) {
|
catch (sycl::exception const &exc) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue