Update vulkan rope implementation to support frequency factors (#7475)
This commit is contained in:
parent
fbf777d2b9
commit
1b1e27cb49
3 changed files with 825 additions and 676 deletions
|
@ -2609,7 +2609,8 @@ layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
|||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {int data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
layout (binding = 2) readonly buffer Z {float data_freq_factors[];};
|
||||
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint ncols;
|
||||
|
@ -2622,6 +2623,7 @@ layout (push_constant) uniform parameter {
|
|||
float corr_dims[4];
|
||||
float theta_scale;
|
||||
float inv_ndims;
|
||||
uint has_freq_facs;
|
||||
} p;
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
|
@ -2671,7 +2673,8 @@ void main() {
|
|||
const float cur_rot = p.inv_ndims * ic - ib;
|
||||
|
||||
const int pos = data_b[i2];
|
||||
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f);
|
||||
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
|
||||
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue