Update vulkan rope implementation to support frequency factors (#7475)

This commit is contained in:
0cc4m 2024-05-23 08:59:59 +02:00 committed by GitHub
parent fbf777d2b9
commit 1b1e27cb49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 825 additions and 676 deletions

View file

@ -2609,7 +2609,8 @@ layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 2) readonly buffer Z {float data_freq_factors[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
@ -2622,6 +2623,7 @@ layout (push_constant) uniform parameter {
float corr_dims[4];
float theta_scale;
float inv_ndims;
uint has_freq_facs;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@ -2671,7 +2673,8 @@ void main() {
const float cur_rot = p.inv_ndims * ic - ib;
const int pos = data_b[i2];
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f);
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
float cos_theta, sin_theta;
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);