kompute: softmax: implement ALiBi support
Signed-off-by: Sergio Lopez <slp@redhat.com>
This commit is contained in:
parent
913536f2a5
commit
d8889598d6
3 changed files with 34 additions and 9 deletions
|
@ -788,7 +788,8 @@ static void ggml_vk_soft_max(
|
||||||
const std::shared_ptr<kp::Tensor>& out,
|
const std::shared_ptr<kp::Tensor>& out,
|
||||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||||
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
|
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
|
||||||
float scale
|
float scale, float max_bias, float m0, float m1,
|
||||||
|
uint32_t n_head_log2
|
||||||
) {
|
) {
|
||||||
const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
|
const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
|
||||||
kp::shader_data::op_softmax_comp_spv_len);
|
kp::shader_data::op_softmax_comp_spv_len);
|
||||||
|
@ -796,12 +797,14 @@ static void ggml_vk_soft_max(
|
||||||
struct PushConstants {
|
struct PushConstants {
|
||||||
uint32_t inAOff, inBOff, outOff;
|
uint32_t inAOff, inBOff, outOff;
|
||||||
int32_t ne00, ne01, ne02;
|
int32_t ne00, ne01, ne02;
|
||||||
float scale;
|
float scale, max_bias, m0, m1;
|
||||||
|
uint32_t n_head_log2;
|
||||||
int32_t mask;
|
int32_t mask;
|
||||||
} pushConsts {
|
} pushConsts {
|
||||||
safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
||||||
ne00, ne01, ne02,
|
ne00, ne01, ne02,
|
||||||
scale,
|
scale, max_bias, m0, m1,
|
||||||
|
n_head_log2,
|
||||||
bool(inB)
|
bool(inB)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1597,11 +1600,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
||||||
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
|
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
|
||||||
|
|
||||||
#pragma message("TODO: add ALiBi support")
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
|
const int64_t nrows_y = src0->ne[1];
|
||||||
GGML_ASSERT(max_bias == 0.0f);
|
|
||||||
|
|
||||||
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
|
const uint32_t n_head = nrows_x/nrows_y;
|
||||||
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||||
|
|
||||||
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
|
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
|
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
|
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
|
||||||
#extension GL_EXT_control_flow_attributes: enable
|
#extension GL_EXT_control_flow_attributes: enable
|
||||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||||
#extension GL_EXT_debug_printf : enable
|
#extension GL_EXT_debug_printf : enable
|
||||||
|
|
|
@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants {
|
||||||
int ne01;
|
int ne01;
|
||||||
int ne02;
|
int ne02;
|
||||||
float scale;
|
float scale;
|
||||||
|
float max_bias;
|
||||||
|
float m0;
|
||||||
|
float m1;
|
||||||
|
uint n_head_log2;
|
||||||
int mask;
|
int mask;
|
||||||
} pcs;
|
} pcs;
|
||||||
|
|
||||||
|
@ -34,17 +38,29 @@ void main() {
|
||||||
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
|
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
|
||||||
const uint pdst = extra_off + pcs.outOff; // Based from out_
|
const uint pdst = extra_off + pcs.outOff; // Based from out_
|
||||||
|
|
||||||
|
float slope = 1.0f;
|
||||||
|
|
||||||
|
// ALiBi
|
||||||
|
if (pcs.max_bias > 0.0f) {
|
||||||
|
int64_t h = i02;
|
||||||
|
|
||||||
|
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
|
||||||
|
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
|
||||||
|
|
||||||
|
slope = pow(base, float(exp));
|
||||||
|
}
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float localMax = uintBitsToFloat(0xFF800000);
|
float localMax = uintBitsToFloat(0xFF800000);
|
||||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||||
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
|
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
|
||||||
}
|
}
|
||||||
float max_ = subgroupMax(localMax);
|
float max_ = subgroupMax(localMax);
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float localSum = 0.0f;
|
float localSum = 0.0f;
|
||||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||||
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
|
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
|
||||||
localSum += exp_psrc0;
|
localSum += exp_psrc0;
|
||||||
out_[pdst + i00] = exp_psrc0;
|
out_[pdst + i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue