rebase with master, support two new OPs, disable the -sm=row feature for SYCL, fix unit test
This commit is contained in:
parent 33563a8a52
commit 4c29df303d
3 changed files with 261 additions and 22 deletions
common/common.cpp

@@ -640,6 +640,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             } else if (arg_next == "layer") {
                 params.split_mode = LLAMA_SPLIT_MODE_LAYER;
             } else if (arg_next == "row") {
+#ifdef GGML_USE_SYCL
+                fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n");
+                exit(1);
+#endif // GGML_USE_SYCL
                 params.split_mode = LLAMA_SPLIT_MODE_ROW;
             } else {
                 invalid_param = true;
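
For reference, this guard makes "-sm row" a hard error on SYCL builds instead of a silent fallback: the process prints the warning and exits before params.split_mode is ever assigned. A stand-alone sketch of the same compile-time guard pattern (illustrative only; the main below is hypothetical and not part of the commit):

// Build with -DGGML_USE_SYCL to make the "row" value a hard error,
// mirroring the patched parser above.
#include <cstdio>
#include <cstdlib>
#include <string>

int main(int argc, char ** argv) {
    const std::string arg_next = argc > 1 ? argv[1] : "layer";
    if (arg_next == "row") {
#ifdef GGML_USE_SYCL
        fprintf(stderr, "warning: split mode [row] is rejected on SYCL builds\n");
        exit(1);
#endif
    }
    printf("split mode accepted: %s\n", arg_next.c_str());
    return 0;
}
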
ggml-sycl.cpp (274 changes)
@@ -3217,6 +3217,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define SYCL_SILU_BLOCK_SIZE 256
 #define SYCL_TANH_BLOCK_SIZE 256
 #define SYCL_RELU_BLOCK_SIZE 256
+#define SYCL_HARDSIGMOID_BLOCK_SIZE 256
+#define SYCL_HARDSWISH_BLOCK_SIZE 256
 #define SYCL_SQR_BLOCK_SIZE 256
 #define SYCL_CPY_BLOCK_SIZE 32
 #define SYCL_SCALE_BLOCK_SIZE 256
@@ -3233,6 +3235,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define SYCL_PAD_BLOCK_SIZE 256
 #define SYCL_ACC_BLOCK_SIZE 256
 #define SYCL_IM2COL_BLOCK_SIZE 256
+#define SYCL_POOL2D_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_SYCL_DMMV_X
@@ -3744,6 +3747,28 @@ static void relu_f32(const float * x, float * dst, const int k,
     dst[i] = sycl::fmax((float)(x[i]), (float)0);
 }
 
+static void hardsigmoid_f32(const float * x, float * dst, const int k,
+                            const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static void hardswish_f32(const float * x, float * dst, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
 static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
                            const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
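
For reference, the two new kernels implement the standard definitions hardsigmoid(x) = clamp((x + 3) / 6, 0, 1) and hardswish(x) = x * hardsigmoid(x), a cheap piecewise-linear stand-in for x * sigmoid(x) (SiLU). A minimal host-side reference for checking the kernels (illustrative, not part of the commit):

// Host-side reference; the formulas match the kernel bodies above exactly.
#include <algorithm>
#include <cstdio>

static float hardsigmoid_ref(float x) { return std::min(1.0f, std::max(0.0f, (x + 3.0f) / 6.0f)); }
static float hardswish_ref(float x)   { return x * hardsigmoid_ref(x); }

int main() {
    // hardsigmoid is 0 for x <= -3, 1 for x >= 3, linear in between.
    const float xs[] = {-4.0f, -1.5f, 0.0f, 1.5f, 4.0f};
    for (float x : xs) {
        printf("x=%5.1f  hardsigmoid=%.3f  hardswish=%.3f\n",
               x, hardsigmoid_ref(x), hardswish_ref(x));
    }
    return 0;
}
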
@@ -7854,6 +7879,13 @@ static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const sycl::half *xi = (const sycl::half *)cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
 static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
     const int16_t *xi = (const int16_t *)cxi;
     int16_t *dsti = (int16_t *)cdsti;
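
The new cpy_1_f16_f32 relies on sycl::half's implicit conversion to float in "*dsti = *xi;". That conversion is exact: every binary16 value is representable in binary32, so this copy direction cannot lose precision. A self-contained decoder showing the mapping, with no SYCL dependency (illustrative only, not part of the commit):

// Decode a raw binary16 bit pattern to float by hand.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float half_bits_to_float(uint16_t h) {
    const uint32_t sign = (h >> 15) & 1;
    const uint32_t exp  = (h >> 10) & 0x1F;
    const uint32_t man  = h & 0x3FF;
    if (exp == 0)    return std::ldexp((sign ? -1.0f : 1.0f) * man, -24);   // subnormal
    if (exp == 0x1F) return man ? NAN : (sign ? -INFINITY : INFINITY);      // inf / NaN
    return std::ldexp((sign ? -1.0f : 1.0f) * (1024 + man), (int)exp - 25); // normal
}

int main() {
    printf("%g %g %g\n",
           half_bits_to_float(0x3C00),   // 1.0
           half_bits_to_float(0xC000),   // -2.0
           half_bits_to_float(0x3555));  // ~0.3333 (nearest fp16 to 1/3)
    return 0;
}
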
@@ -8451,6 +8483,62 @@ static void im2col_kernel(const float *x, T *dst, int offset_delta,
     }
 }
 
+template <typename Ti, typename To>
+static void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op,
+        const sycl::nd_item<3> &item_ct1) {
+    int idx = item_ct1.get_local_id(2) +
+              item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    if (idx >= parallel_elements) {
+        return;
+    }
+
+    const int I_HW = ih * iw;
+    const int O_HW = oh * ow;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / ow;
+    const int cur_ow = idx % O_HW % ow;
+    const Ti* i_ptr = src + nc * I_HW;
+    To* o_ptr = dst + nc * O_HW;
+    const int start_h = cur_oh * sh - ph;
+    const int bh = sycl::max(0, start_h);
+    const int eh = sycl::min(ih, start_h + kh);
+    const int start_w = cur_ow * sw - pw;
+    const int bw = sycl::max(0, start_w);
+    const int ew = sycl::min(iw, start_w + kw);
+
+    To res = 0;
+
+    switch (op) {
+        case GGML_OP_POOL_AVG: res = 0; break;
+        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+    }
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+#if DPCT_COMPATIBILITY_TEMP >= 350
+            /*
+            DPCT1098:106: The '*' expression is used instead of the __ldg
+            call. These two expressions do not provide the exact same
+            functionality. Check the generated code for potential precision
+            and/or performance issues.
+            */
+            Ti cur = *(i_ptr + i * iw + j);
+#else
+            Ti cur = i_ptr[i * iw + j];
+#endif
+            switch (op) {
+                case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
+                case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
+            }
+        }
+    }
+    o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
 template <int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_sycl(const ggml_tensor *src0, const ggml_tensor *src1,
                           ggml_tensor *dst, const void *src0_dd,
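
The kernel flattens one work-item per output element: idx decomposes into (channel nc, output row cur_oh, output column cur_ow), and the pooling window is clamped to the input bounds. Note that the AVG case divides by the full kernel area kh * kw even when padding clips the window (count_include_pad semantics), and MAX seeds with -FLT_MAX. A host-side sketch of the same arithmetic for a single averaged output element (illustrative, not part of the commit):

#include <algorithm>
#include <cstdio>
#include <vector>

// One AVG output element, mirroring pool2d_nchw_kernel's indexing: the window
// [start_h, start_h+kh) x [start_w, start_w+kw) is clamped to the input, but
// the divisor stays kh*kw (count_include_pad behaviour).
static float pool2d_avg_one(const std::vector<float> & in, int ih, int iw,
                            int oh_idx, int ow_idx, int kh, int kw,
                            int sh, int sw, int ph, int pw) {
    const int start_h = oh_idx * sh - ph;
    const int start_w = ow_idx * sw - pw;
    const int bh = std::max(0, start_h), eh = std::min(ih, start_h + kh);
    const int bw = std::max(0, start_w), ew = std::min(iw, start_w + kw);
    float res = 0.0f;
    for (int i = bh; i < eh; ++i)
        for (int j = bw; j < ew; ++j)
            res += in[i * iw + j] / (kh * kw);
    return res;
}

int main() {
    std::vector<float> in(4 * 4, 1.0f); // a 4x4 plane of ones
    // 2x2 kernel, stride 2, no padding: every window averages to 1.00
    printf("%.2f\n", pool2d_avg_one(in, 4, 4, 0, 0, 2, 2, 2, 2, 0, 0));
    return 0;
}
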
@@ -8739,6 +8827,30 @@ static void relu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
+static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
+                                 dpct::queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardsigmoid_f32(x, dst, k, item_ct1);
+        });
+}
+
+static void hardswish_f32_sycl(const float *x, float *dst, const int k,
+                               dpct::queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardswish_f32(x, dst, k, item_ct1);
+        });
+}
+
 static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
                                 const float negative_slope,
                                 dpct::queue_ptr stream) {
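
These launchers all size their grid with the ceiling division (k + BLOCK - 1) / BLOCK. Note that SYCL's nd_range takes the global size first, so the block count is multiplied back by the block size, unlike CUDA's <<<blocks, threads>>> convention where blocks and threads are passed separately. A tiny check of the arithmetic (illustrative):

#include <cstdio>

int main() {
    const int k = 1000, block = 256;                // SYCL_*_BLOCK_SIZE above
    const int num_blocks = (k + block - 1) / block; // ceil(1000 / 256) = 4
    printf("global=%d local=%d (covers %d >= %d elements)\n",
           num_blocks * block, block, num_blocks * block, k);
    return 0;
}
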
@@ -10593,6 +10705,31 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     }
 }
 
+static void
+ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
+                      const int ne01, const int ne02, const int nb00,
+                      const int nb01, const int nb02, const int nb03,
+                      const int ne10, const int ne11, const int ne12,
+                      const int nb10, const int nb11, const int nb12,
+                      const int nb13, dpct::queue_ptr stream) {
+
+    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
+                                           nb01, nb02, nb03, ne10, ne11, ne12,
+                                           nb10, nb11, nb12, nb13, item_ct1);
+            });
+    }
+}
+
 static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
                                   const int ne00, const int ne01,
                                   const int ne02, const int nb00,
@@ -11779,7 +11916,6 @@ inline void ggml_sycl_op_tanh(const ggml_tensor *src0, const ggml_tensor *src1,
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
     (void) src1;
@@ -11802,6 +11938,37 @@ inline void ggml_sycl_op_relu(const ggml_tensor *src0, const ggml_tensor *src1,
     (void) src1_dd;
 }
 
+static void ggml_sycl_op_hardsigmoid(const ggml_tensor *src0,
+                                     const ggml_tensor *src1, ggml_tensor *dst,
+                                     const float *src0_dd, const float *src1_dd,
+                                     float *dst_dd,
+                                     const dpct::queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+static void ggml_sycl_op_hardswish(const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const dpct::queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_sycl_op_leaky_relu(const ggml_tensor *src0,
                                     const ggml_tensor *src1, ggml_tensor *dst,
                                     const float *src0_dd, const float *src1_dd,
@@ -12283,7 +12450,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
         to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
     }
     const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
-            ? (const sycl::half *)src1_ddf_i
+            ? (const sycl::half *)src1->data + src1_padded_row_size
            : src1_as_f16.get();
    sycl_pool_alloc<sycl::half> dst_f16(row_diff * src1_ncols);
 
@@ -12451,6 +12618,48 @@ inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
     (void) src1_dd;
 }
 
+static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const float *src0_dd, const float *src1_dd,
+                                float *dst_dd, const dpct::queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
+
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
+
+    const int parallel_elements = N * OC * OH * OW;
+    const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
+    sycl::range<3> block_nums(1, 1, num_blocks);
+    main_stream->parallel_for(
+        sycl::nd_range<3>(block_nums *
+                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
+                               parallel_elements, src0_dd, dst_dd, op,
+                               item_ct1);
+        });
+
+    (void) src1;
+    (void) src1_dd;
+}
+
 inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
                                 const float *src0_dd, const float *src1_dd,
@@ -12796,7 +13005,6 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
 
     GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
     GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
     GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
@@ -12815,7 +13023,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
     const bool src1_is_contiguous = ggml_is_contiguous(src1);
 
-    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+    int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
     const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
     GGML_ASSERT(!(split && ne02 > 1));
@@ -13022,7 +13230,9 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
             if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
                 SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
             }
-
+            if (src1->type == GGML_TYPE_F16) {
+                src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
+            }
             // do the computation
             op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
                 dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream);
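
This pairs with the src1_ptr change in the -12283 hunk above: when src1 is already F16 the kernel reads directly from src1->data instead of a staged copy, so the element offset of the current column block has to be recomputed each iteration as (i0 * ne11 + src1_col_0) * ne10. A small demonstration of that offset arithmetic with made-up shapes (illustrative, not part of the commit):

#include <cstdio>

int main() {
    const long ne10 = 4096;  // row length in elements (hypothetical)
    const long ne11 = 32;    // rows per batch (hypothetical)
    for (long i0 = 0; i0 < 2; ++i0) {                               // batch index
        for (long src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += 16) {
            const long off = (i0 * ne11 + src1_col_0) * ne10;       // as in the hunk
            printf("batch %ld, col block %ld -> element offset %ld\n",
                   i0, src1_col_0, off);
        }
    }
    return 0;
}
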
@@ -13200,6 +13410,18 @@ static void ggml_sycl_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+static void ggml_sycl_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_hardsigmoid);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_hardswish);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 static void ggml_sycl_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_leaky_relu);
@@ -13981,6 +14203,8 @@ static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1,
         ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
@@ -14024,6 +14248,10 @@ static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
 }
 
+static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
+}
+
 static void ggml_sycl_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_im2col);
 }
@@ -14322,6 +14550,12 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
             case GGML_UNARY_OP_RELU:
                 func = ggml_sycl_relu;
                 break;
+            case GGML_UNARY_OP_HARDSIGMOID:
+                func = ggml_sycl_hardsigmoid;
+                break;
+            case GGML_UNARY_OP_HARDSWISH:
+                func = ggml_sycl_hardswish;
+                break;
             default:
                 return false;
         }
@@ -14396,6 +14630,9 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_IM2COL:
             func = ggml_sycl_im2col;
             break;
+        case GGML_OP_POOL_2D:
+            func = ggml_sycl_pool2d;
+            break;
         case GGML_OP_SUM_ROWS:
             func = ggml_sycl_sum_rows;
             break;
@@ -14776,7 +15013,6 @@ catch (sycl::exception const &exc) {
 
 GGML_CALL static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 128;
-
     UNUSED(buft);
 }
 
@@ -15134,7 +15370,6 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc
 
 GGML_CALL static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 128;
-
     UNUSED(buft);
 }
 
@@ -15253,6 +15488,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggm
     // FIXME: this is a hack to avoid having to implement a new buffer type
     ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
     buffer->buft = buft;
+    buffer->iface.get_name = ggml_backend_sycl_host_buffer_name;
     buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
 
     return buffer;
@@ -15366,7 +15602,6 @@ catch (sycl::exception const &exc) {
 
 GGML_CALL static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-
     ggml_sycl_set_main_device(sycl_ctx->device);
 
     ggml_compute_params params = {};
@@ -15390,7 +15625,6 @@ GGML_CALL static bool ggml_backend_sycl_graph_compute(ggml_backend_t backend, gg
             }
         }
 #endif
-
         bool ok = ggml_sycl_compute_forward(&params, node);
         if (!ok) {
             fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -15485,16 +15719,17 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
             if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                 return true;
             }
+            if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+                return true;
+            }
             return false;
         } break;
+        case GGML_OP_DUP:
+        case GGML_OP_REPEAT:
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
-                if (src0_type == GGML_TYPE_F32) {
-                    return true;
-                } else {
-                    return false;
-                }
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -15502,8 +15737,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
-        case GGML_OP_REPEAT:
-        case GGML_OP_DUP:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -15517,6 +15750,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ROPE:
         case GGML_OP_ALIBI:
        case GGML_OP_IM2COL:
+        case GGML_OP_POOL_2D:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
@@ -15592,13 +15826,13 @@ extern "C" int ggml_backend_sycl_reg_devices();
 
 int ggml_backend_sycl_reg_devices() {
     if (!g_sycl_gpu_mgr) g_sycl_gpu_mgr = new sycl_gpu_mgr();
-
-    int device_count = g_sycl_gpu_mgr->get_gpu_count();
-    for (int i = 0; i < device_count; i++) {
+    g_device_count = g_sycl_gpu_mgr->get_gpu_count();
+    assert(g_device_count>0);
+    for (int i = 0; i < g_device_count; i++) {
         int id = g_sycl_gpu_mgr->gpus[i];
         char name[128];
         snprintf(name, sizeof(name), "%s%d", GGML_SYCL_NAME, id);
         ggml_backend_register(name, ggml_backend_reg_sycl_init, ggml_backend_sycl_buffer_type(i), (void *) (intptr_t) i);
     }
-    return device_count;
+    return g_device_count;
 }

llama.cpp
@@ -1411,7 +1411,9 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer
         buft = ggml_backend_cuda_host_buffer_type();
     }
 #elif defined(GGML_USE_SYCL)
+    if (host_buffer) {
         buft = ggml_backend_sycl_host_buffer_type();
+    }
 #elif defined(GGML_USE_CPU_HBM)
     buft = ggml_backend_cpu_hbm_buffer_type();
 #elif defined(GGML_USE_VULKAN)
@@ -12095,7 +12097,6 @@ struct llama_context * llama_new_context_with_model(
         ggml_set_name(ctx->inp_cls, "inp_cls");
-
         ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
 
         LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
                 ggml_backend_buffer_name(ctx->buf_input),
                 ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0);