implement ggml_repeat support for rank > 2 tensors
This commit is contained in:
parent
7a15a8370c
commit
e6186d98a5
1 changed files with 42 additions and 21 deletions
63
ggml.c
63
ggml.c
|
@ -8854,37 +8854,58 @@ static void ggml_compute_forward_repeat_f32(
|
|||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
assert(params->ith == 0);
|
||||
assert(ggml_can_repeat(src0, dst));
|
||||
GGML_ASSERT(params->ith == 0);
|
||||
GGML_ASSERT(ggml_can_repeat(src0, dst));
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: implement support for rank > 2 tensors
|
||||
assert(src0->ne[2] == 1);
|
||||
assert(src0->ne[3] == 1);
|
||||
assert( dst->ne[2] == 1);
|
||||
assert( dst->ne[3] == 1);
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
|
||||
const int nc = dst->ne[0];
|
||||
const int nr = dst->ne[1];
|
||||
const int nc0 = src0->ne[0];
|
||||
const int nr0 = src0->ne[1];
|
||||
const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const size_t nb0 = dst->nb[0];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
const size_t nb2 = dst->nb[2];
|
||||
const size_t nb3 = dst->nb[3];
|
||||
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
const size_t nb02 = src0->nb[2];
|
||||
const size_t nb03 = src0->nb[3];
|
||||
|
||||
// guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int nr0 = (int)(ne0/ne00);
|
||||
const int nr1 = (int)(ne1/ne01);
|
||||
const int nr2 = (int)(ne2/ne02);
|
||||
const int nr3 = (int)(ne3/ne03);
|
||||
|
||||
// TODO: support for transposed / permuted tensors
|
||||
assert( dst->nb[0] == sizeof(float));
|
||||
assert(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
|
||||
// TODO: maybe this is not optimal?
|
||||
for (int i = 0; i < nrr; i++) {
|
||||
for (int j = 0; j < ncr; j++) {
|
||||
for (int k = 0; k < nr0; k++) {
|
||||
ggml_vec_cpy_f32(nc0,
|
||||
(float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])),
|
||||
(float *) ((char *) src0->data + ( k)*(src0->nb[1])));
|
||||
for (int i3 = 0; i3 < nr3; i3++) {
|
||||
for (int k3 = 0; k3 < ne03; k3++) {
|
||||
for (int i2 = 0; i2 < nr2; i2++) {
|
||||
for (int k2 = 0; k2 < ne02; k2++) {
|
||||
for (int i1 = 0; i1 < nr1; i1++) {
|
||||
for (int k1 = 0; k1 < ne01; k1++) {
|
||||
for (int i0 = 0; i0 < nr0; i0++) {
|
||||
ggml_vec_cpy_f32(ne00,
|
||||
(float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
|
||||
(float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue