parent
3dfda05956
commit
16bdfa42ac
4 changed files with 218 additions and 66 deletions
|
@ -291,29 +291,6 @@ static void sqr_f32(const float * x, float * dst, const int k,
|
|||
dst[i] = x[i] * x[i];
|
||||
}
|
||||
|
||||
static void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (item_ct1.get_group(0) < ne02) { // src0
|
||||
int offset_src =
|
||||
nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
int offset_src =
|
||||
nidx + item_ct1.get_group(1) * ne0 +
|
||||
(item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = y[offset_src];
|
||||
}
|
||||
}
|
||||
|
||||
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
|
@ -1347,20 +1324,6 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k,
|
|||
});
|
||||
}
|
||||
|
||||
static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int ne0, int ne1, int ne2, int ne02,
|
||||
queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32(x, y, dst, ne0, ne02, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
|
@ -2429,28 +2392,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
#pragma message("TODO: generalize concat kernel for dim != 2")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
|
||||
int dim = dst->op_params[0];
|
||||
GGML_ASSERT(dim == 2);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
|
||||
concat_f32_sycl(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
|
||||
}
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
|
@ -3359,12 +3300,6 @@ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_ten
|
|||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_concat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_concat);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
|
||||
|
@ -4101,7 +4036,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
|||
func = ggml_sycl_group_norm;
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
func = ggml_sycl_concat;
|
||||
func = ggml_sycl_op_concat;
|
||||
break;
|
||||
case GGML_OP_UPSCALE:
|
||||
func = ggml_sycl_upscale;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue