implement ggml_repeat support for rank > 2 tensors

This commit is contained in:
xaedes 2023-05-06 18:01:17 +02:00
parent 7a15a8370c
commit e6186d98a5
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

63
ggml.c
View file

@ -8854,37 +8854,58 @@ static void ggml_compute_forward_repeat_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
assert(params->ith == 0);
assert(ggml_can_repeat(src0, dst));
GGML_ASSERT(params->ith == 0);
GGML_ASSERT(ggml_can_repeat(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// TODO: implement support for rank > 2 tensors
assert(src0->ne[2] == 1);
assert(src0->ne[3] == 1);
assert( dst->ne[2] == 1);
assert( dst->ne[3] == 1);
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
const int nc = dst->ne[0];
const int nr = dst->ne[1];
const int nc0 = src0->ne[0];
const int nr0 = src0->ne[1];
const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const size_t nb0 = dst->nb[0];
const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];
// guaranteed to be an integer due to the check in ggml_can_repeat
const int nr0 = (int)(ne0/ne00);
const int nr1 = (int)(ne1/ne01);
const int nr2 = (int)(ne2/ne02);
const int nr3 = (int)(ne3/ne03);
// TODO: support for transposed / permuted tensors
assert( dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
// TODO: maybe this is not optimal?
for (int i = 0; i < nrr; i++) {
for (int j = 0; j < ncr; j++) {
for (int k = 0; k < nr0; k++) {
ggml_vec_cpy_f32(nc0,
(float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])),
(float *) ((char *) src0->data + ( k)*(src0->nb[1])));
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne03; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne02; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne01; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
ggml_vec_cpy_f32(ne00,
(float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
(float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
}
}
}
}
}
}
}