vulkan: Use push constant offset to handle misaligned descriptors (#10987)

This commit is contained in:
Jeff Bolz 2024-12-29 02:35:11 -06:00 committed by GitHub
parent f865ea149d
commit fdd2188912
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 103 additions and 42 deletions

View file

@ -21,9 +21,9 @@ void main() {
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
}
}

View file

@ -22,7 +22,7 @@ void main() {
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}

View file

@ -30,12 +30,12 @@ void main() {
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
#else
if (is_src0) {
data_d[p.d_offset + dst_idx] = data_a[src0_idx];
data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
} else {
data_d[p.d_offset + dst_idx] = data_b[src1_idx];
data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
}
#endif
}

View file

@ -19,9 +19,9 @@ void main() {
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[p.d_offset + idx] = data_a[idx];
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
@ -32,9 +32,9 @@ void main() {
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[p.d_offset + idx] = data_a[idx];
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}

View file

@ -13,8 +13,8 @@ void main() {
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
#endif
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}

View file

@ -20,7 +20,7 @@ void main() {
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}

View file

@ -7,7 +7,7 @@ layout (push_constant) uniform parameter
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint d_offset;
uint misalign_offsets;
float param1; float param2; int param3;
} p;
@ -22,6 +22,10 @@ uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
// Offset that callers add to their data_a index.
// It is stored in the bits above bit 15 of the packed p.misalign_offsets push constant.
uint get_aoffset() {
    const uint packed_offsets = p.misalign_offsets;
    return packed_offsets >> 16;
}
// Offset that callers add to their data_b index.
// It is stored in bits 8..15 of the packed p.misalign_offsets push constant.
uint get_boffset() {
    const uint packed_offsets = p.misalign_offsets;
    // Isolate the second byte first, then move it down to bit 0.
    return (packed_offsets & 0xFF00u) >> 8;
}
// Offset that callers add to their data_d index.
// It is stored in the lowest 8 bits of the packed p.misalign_offsets push constant.
uint get_doffset() {
    const uint packed_offsets = p.misalign_offsets;
    return packed_offsets & 0xFFu;
}
// mod and div are expensive, and coordinates/dimensions are often a power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {

View file

@ -6,7 +6,7 @@ layout (push_constant) uniform parameter
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint d_offset;
uint misalign_offsets;
float param1; float param2;
uint ne0_012mp; uint ne0_012L;
@ -24,6 +24,9 @@ uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
// Offset that callers add to their data_a index.
// It is stored in the upper 16 bits of the packed p.misalign_offsets push constant.
uint get_aoffset() {
    const uint packed_offsets = p.misalign_offsets;
    return packed_offsets >> 16;
}
// Offset that callers add to their data_d index.
// It is stored in the lower 16 bits of the packed p.misalign_offsets push constant.
uint get_doffset() {
    const uint packed_offsets = p.misalign_offsets;
    return packed_offsets & 0xFFFFu;
}
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;

View file

@ -15,10 +15,10 @@ void main() {
return;
}
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);

View file

@ -20,7 +20,7 @@ void main() {
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}

View file

@ -24,5 +24,5 @@ void main() {
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
}

View file

@ -22,5 +22,5 @@ void main() {
return;
}
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]);
}

View file

@ -18,7 +18,7 @@ void main() {
continue;
}
data_d[p.d_offset + idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
idx += num_threads;
}
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
}

View file

@ -2,7 +2,7 @@
layout (push_constant) uniform parameter
{
uint ne; uint d_offset;
uint ne; uint a_offset; uint d_offset;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
@ -32,5 +32,5 @@ void main() {
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
}