ggml : add GGML_PAD_REFLECT_1D operation (ggml/1034)

* ggml_pad_reflect_1d defined in header

* implemented on CPU

* called the forward pass

* impl Metal kernel

* added Metal kernel

* added OP_PAD_REFLECT_1D in test-backend-ops.cpp

* add test-pad-reflect-1d test case

* test case support multiple backend
This commit is contained in:
PAB 2024-12-03 20:20:04 +01:00 committed by Georgi Gerganov
parent d405804be8
commit c2082d93a8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 192 additions and 2 deletions

View file

@ -10439,6 +10439,40 @@ static void ggml_compute_forward_pad(
}
}
// ggml_compute_forward_pad_reflect_1d
static void ggml_compute_forward_pad_reflect_1d(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int ith = params->ith;
const int nth = params->nth;
const int32_t * opts = (const int32_t *) dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];
GGML_TENSOR_UNARY_OP_LOCALS
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
}
}
}
}
// ggml_compute_forward_arange
@ -12535,6 +12569,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_pad(params, tensor);
} break;
case GGML_OP_PAD_REFLECT_1D:
{
ggml_compute_forward_pad_reflect_1d(params, tensor);
} break;
case GGML_OP_ARANGE:
{
ggml_compute_forward_arange(params, tensor);
@ -12877,6 +12915,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT: