[SYCL] add concat through dim 1/2 (#8483)

* add concat through dim 1/2
This commit is contained in:
Meng, Hengyu 2024-07-15 19:32:15 +08:00 committed by GitHub
parent 3dfda05956
commit 16bdfa42ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 218 additions and 66 deletions

View file

@ -291,29 +291,6 @@ static void sqr_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * x[i];
}
// Device kernel: writes one float of dst per work-item, taking it from x or y.
//
// Index mapping, as used below:
//   - the global index along nd-range dimension 2 is the position inside a row
//     of ne0 elements (work-items at or beyond ne0 do nothing);
//   - the group index along dimension 1 selects the row;
//   - the group index along dimension 0 selects the plane.
// Planes with index below ne02 are copied from x; the remaining planes are
// copied from y, with the plane index shifted down by ne02.
//
// NOTE(review): every offset uses the launch's dimension-1 group count as the
// number of rows per plane, so x, y and dst are assumed to be densely packed
// with the same row length and row count - confirm against the caller.
static void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02,
                       const sycl::nd_item<3> &item_ct1) {
    const int col = item_ct1.get_local_id(2) +
                    item_ct1.get_group(2) * item_ct1.get_local_range(2);
    if (col >= ne0) {
        return;
    }

    const size_t row          = item_ct1.get_group(1);
    const size_t plane        = item_ct1.get_group(0);
    const size_t plane_stride = ne0 * item_ct1.get_group_range(1);

    // Offset of this element inside its plane; shared by source and destination.
    const size_t in_plane = col + row * ne0;

    // Choose the source buffer and the plane index inside that buffer.
    const float *src       = x;
    size_t       src_plane = plane;
    if (plane >= ne02) { // past the planes supplied by x: read from y
        src       = y;
        src_plane = plane - ne02;
    }

    const int offset_dst = in_plane + plane * plane_stride;
    const int offset_src = in_plane + src_plane * plane_stride;
    dst[offset_dst] = src[offset_src];
}
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
@ -1347,20 +1324,6 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k,
});
}
static void concat_f32_sycl(const float *x, const float *y, float *dst,
const int ne0, int ne1, int ne2, int ne02,
queue_ptr stream) {
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32(x, y, dst, ne0, ne02, item_ct1);
});
}
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
@ -2429,28 +2392,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor
(void) src1_dd;
}
// Concat operator: joins src0 and src1 into dst along dimension 2.
//
// Only dim == 2 (read from dst->op_params[0]) and F32 tensors are accepted;
// anything else trips a GGML_ASSERT. One concat_f32_sycl launch is issued on
// main_stream per index along dst->ne[3].
//
// src0_dd / src1_dd / dst_dd are the float buffers that are actually read and
// written; the tensor structs are consulted only for shape and stride.
// NOTE(review): nb[3] / 4 is presumably the byte stride of dimension 3
// converted to a count of 4-byte floats - confirm against the ggml_tensor
// definition.
inline void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                const ggml_tensor *src1, ggml_tensor *dst,
                                const float *src0_dd, const float *src1_dd,
                                float *dst_dd,
                                const queue_ptr &main_stream) {
#pragma message("TODO: generalize concat kernel for dim != 2")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
    int dim = dst->op_params[0];
    GGML_ASSERT(dim == 2);

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    for (int i3 = 0; i3 < dst->ne[3]; ++i3) {
        // Start of the i3-th slice in each buffer.
        const float * src0_slice = src0_dd + i3 * (src0->nb[3] / 4);
        const float * src1_slice = src1_dd + i3 * (src1->nb[3] / 4);
        float *       dst_slice  = dst_dd + i3 * (dst->nb[3] / 4);

        concat_f32_sycl(src0_slice, src1_slice, dst_slice,
                        dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2],
                        main_stream);
    }

    (void) src1;
    (void) dst;
}
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
@ -3359,12 +3300,6 @@ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_ten
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
// Entry point for the concat op: hands src0, src1 and dst to
// ggml_sycl_op_flatten with ggml_sycl_op_concat as the operator to run,
// emitting debug messages before and after the call.
static void ggml_sycl_concat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_concat);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
@ -4101,7 +4036,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
func = ggml_sycl_group_norm;
break;
case GGML_OP_CONCAT:
func = ggml_sycl_concat;
func = ggml_sycl_op_concat;
break;
case GGML_OP_UPSCALE:
func = ggml_sycl_upscale;