vulkan: implement GGML_OP_REPEAT_BACK

This commit is contained in:
Rémy O 2025-02-09 10:07:20 +01:00
parent e6a2c06bbb
commit bc349762d8
3 changed files with 72 additions and 1 deletions

View file

@ -0,0 +1,37 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
// Destination multi-index (inlined dst_idx)
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
// Accumulate from sources
A_TYPE acc = A_TYPE(0);
for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
}
}
}
}
data_d[get_doffset() + d_idx] = D_TYPE(acc);
}

View file

@ -445,6 +445,7 @@ void process_shaders() {
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});