kompute: rope: implement neox and phi3 support

Signed-off-by: Sergio Lopez <slp@redhat.com>
This commit is contained in:
Sergio Lopez 2024-11-22 14:35:50 +01:00
parent d8889598d6
commit 1b8afa88dc
9 changed files with 258 additions and 176 deletions

View file

@ -105,8 +105,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
kompute-shaders/op_getrows_q4_0.comp kompute-shaders/op_getrows_q4_0.comp
kompute-shaders/op_getrows_q4_1.comp kompute-shaders/op_getrows_q4_1.comp
kompute-shaders/op_getrows_q6_k.comp kompute-shaders/op_getrows_q6_k.comp
kompute-shaders/op_rope_f16.comp kompute-shaders/op_rope_norm_f16.comp
kompute-shaders/op_rope_f32.comp kompute-shaders/op_rope_norm_f32.comp
kompute-shaders/op_rope_neox_f16.comp
kompute-shaders/op_rope_neox_f32.comp
kompute-shaders/op_cpy_f16_f16.comp kompute-shaders/op_cpy_f16_f16.comp
kompute-shaders/op_cpy_f16_f32.comp kompute-shaders/op_cpy_f16_f32.comp
kompute-shaders/op_cpy_f32_f16.comp kompute-shaders/op_cpy_f32_f16.comp
@ -139,8 +141,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
shaderop_getrows_q4_0.h shaderop_getrows_q4_0.h
shaderop_getrows_q4_1.h shaderop_getrows_q4_1.h
shaderop_getrows_q6_k.h shaderop_getrows_q6_k.h
shaderop_rope_f16.h shaderop_rope_norm_f16.h
shaderop_rope_f32.h shaderop_rope_norm_f32.h
shaderop_rope_neox_f16.h
shaderop_rope_neox_f32.h
shaderop_cpy_f16_f16.h shaderop_cpy_f16_f16.h
shaderop_cpy_f16_f32.h shaderop_cpy_f16_f32.h
shaderop_cpy_f32_f16.h shaderop_cpy_f32_f16.h

View file

@ -28,8 +28,10 @@
#include "shaderop_getrows_q4_0.h" #include "shaderop_getrows_q4_0.h"
#include "shaderop_getrows_q4_1.h" #include "shaderop_getrows_q4_1.h"
#include "shaderop_getrows_q6_k.h" #include "shaderop_getrows_q6_k.h"
#include "shaderop_rope_f16.h" #include "shaderop_rope_norm_f16.h"
#include "shaderop_rope_f32.h" #include "shaderop_rope_norm_f32.h"
#include "shaderop_rope_neox_f16.h"
#include "shaderop_rope_neox_f32.h"
#include "shaderop_cpy_f16_f16.h" #include "shaderop_cpy_f16_f16.h"
#include "shaderop_cpy_f16_f32.h" #include "shaderop_cpy_f16_f32.h"
#include "shaderop_cpy_f32_f16.h" #include "shaderop_cpy_f32_f16.h"
@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = { std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
vk::DescriptorPoolSize( vk::DescriptorPoolSize(
vk::DescriptorType::eStorageBuffer, vk::DescriptorType::eStorageBuffer,
3 * size // Descriptor count is number of possible tensors to pass into an algorithm 4 * size // Descriptor count is number of possible tensors to pass into an algorithm
) )
}; };
@ -1220,10 +1222,11 @@ static void ggml_vk_rope(
kp::Sequence& seq, kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA, const std::shared_ptr<kp::Tensor>& inA,
const std::shared_ptr<kp::Tensor>& inB, const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& inC,
const std::shared_ptr<kp::Tensor>& out, const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff, uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig, ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
int32_t ne01, int32_t ne02, int32_t ne03, int32_t ne01, int32_t ne02, int32_t ne03,
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
int32_t ne0, int32_t ne0,
@ -1231,11 +1234,17 @@ static void ggml_vk_rope(
) { ) {
GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32); GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
static const auto spirv_f16 = getSpirvShader( static const auto spirv_norm_f16 = getSpirvShader(
kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
); );
static const auto spirv_f32 = getSpirvShader( static const auto spirv_norm_f32 = getSpirvShader(
kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
);
static const auto spirv_neox_f16 = getSpirvShader(
kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
);
static const auto spirv_neox_f32 = getSpirvShader(
kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
); );
int type_size = src0t == GGML_TYPE_F16 ? 2 : 4; int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
@ -1250,32 +1259,40 @@ static void ggml_vk_rope(
GGML_ASSERT(nb0 % type_size == 0); GGML_ASSERT(nb0 % type_size == 0);
struct PushConstants { struct PushConstants {
uint32_t inAOff, inBOff, outOff; uint32_t inAOff, inBOff, inCOff, outOff;
int32_t n_dims, mode, n_ctx_orig; int32_t n_dims, mode, n_ctx_orig;
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale;
bool has_freq_factors;
float ext_factor, attn_factor, beta_fast, beta_slow;
uint32_t nb00, nb01, nb02, nb03; uint32_t nb00, nb01, nb02, nb03;
int32_t ne0; int32_t ne0;
uint32_t nb0, nb1, nb2, nb3; uint32_t nb0, nb1, nb2, nb3;
} pushConsts { } pushConsts {
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size), safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
n_dims, mode, n_ctx_orig, n_dims, mode, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, freq_base, freq_scale,
has_freq_factors,
ext_factor, attn_factor, beta_fast, beta_slow,
nb00, nb01, nb02, nb03, nb00, nb01, nb02, nb03,
ne0, ne0,
nb0, nb1, nb2, nb3 nb0, nb1, nb2, nb3
}; };
auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32"); auto & inC_ = inC ? inC : inA;
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_f16 = src0t == GGML_TYPE_F16;
auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
std::shared_ptr<kp::Algorithm> s_algo = nullptr; std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(name)) { if (!komputeManager()->hasAlgorithm(name)) {
auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
s_algo = komputeManager()->algorithm<float, PushConstants>( s_algo = komputeManager()->algorithm<float, PushConstants>(
name, s_kompute_context->pool.get(), {inA, inB, out}, name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
{unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts} {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
); );
} else { } else {
s_algo = komputeManager()->getAlgorithm(name); s_algo = komputeManager()->getAlgorithm(name);
s_algo->setTensors({inA, inB, out}); s_algo->setTensors({inA, inB, inC_, out});
s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
s_algo->setPushConstants<PushConstants>({pushConsts}); s_algo->setPushConstants<PushConstants>({pushConsts});
s_algo->updateDescriptors(s_kompute_context->pool.get()); s_algo->updateDescriptors(s_kompute_context->pool.get());
@ -1522,9 +1539,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
const static std::shared_ptr<kp::Tensor> nullTensor = nullptr; const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
uint32_t off_src0 = 0; uint32_t off_src0 = 0;
uint32_t off_src1 = 0; uint32_t off_src1 = 0;
uint32_t off_src2 = 0;
uint32_t off_dst = 0; uint32_t off_dst = 0;
const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor; const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor; const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor; const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
switch (dst->op) { switch (dst->op) {
@ -1721,13 +1740,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
#pragma message("TODO: update rope NORM mode to match NEOX mode")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
GGML_ASSERT(ne10 == ne02); GGML_ASSERT(ne10 == ne02);
GGML_ASSERT(src0t == dstt); GGML_ASSERT(src0t == dstt);
// const int n_past = ((int32_t *) dst->op_params)[0]; // const int n_past = ((int32_t *) dst->op_params)[0];
@ -1736,6 +1748,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
const bool has_freq_factors = dst->src[2] != nullptr;
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@ -1744,8 +1758,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
ggml_vk_rope( ggml_vk_rope(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig, seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
); );
} break; } break;

View file

@ -1,73 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
const int p = inB[pcs.inBOff + i2];
float theta = float(p);
if (!is_neox) {
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+1]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
}
} else {
const float inv_ndims = -1.f/pcs.n_dims;
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
const uint cur_rot = ic;
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint i0 = ic/2;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+pcs.n_dims/2]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
}
for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
const uint i0 = ic;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data + 0] = inA[src + 0];
out_[dst_data + 1] = inA[src + 1];
}
}
}

View file

@ -1,73 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
const int p = inB[pcs.inBOff + i2];
float theta = float(p);
if (!is_neox) {
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+1];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
}
} else {
const float inv_ndims = -1.f/pcs.n_dims;
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
const uint cur_rot = ic;
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint i0 = ic/2;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+pcs.n_dims/2];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
}
for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
const uint i0 = ic;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data + 0] = inA[src + 0];
out_[dst_data + 1] = inA[src + 1];
}
}
}

View file

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+pcs.n_dims/2]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View file

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+pcs.n_dims/2];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View file

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+1]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View file

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+1];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View file

@ -8,12 +8,14 @@ layout(local_size_x = 1) in;
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint inAOff; uint inAOff;
uint inBOff; uint inBOff;
uint inCOff;
uint outOff; uint outOff;
int n_dims; int n_dims;
int mode; int mode;
int n_ctx_orig; int n_ctx_orig;
float freq_base; float freq_base;
float freq_scale; float freq_scale;
bool has_freq_factors;
float ext_factor; float ext_factor;
float attn_factor; float attn_factor;
float beta_fast; float beta_fast;