kompute : support mask parameter of softmax
This commit is contained in:
parent
8bd38fe32d
commit
df687b10ab
2 changed files with 24 additions and 14 deletions
|
@@ -764,9 +764,10 @@ static void ggml_vk_gelu(Args&&... args) {
|
||||||
|
|
||||||
static void ggml_vk_soft_max(
|
static void ggml_vk_soft_max(
|
||||||
kp::Sequence& seq,
|
kp::Sequence& seq,
|
||||||
const std::shared_ptr<kp::Tensor>& in,
|
const std::shared_ptr<kp::Tensor>& inA,
|
||||||
|
const std::shared_ptr<kp::Tensor>& inB,
|
||||||
const std::shared_ptr<kp::Tensor>& out,
|
const std::shared_ptr<kp::Tensor>& out,
|
||||||
uint32_t inOff, uint32_t outOff,
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||||
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
|
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
|
||||||
float scale
|
float scale
|
||||||
) {
|
) {
|
||||||
|
@@ -774,22 +775,27 @@ static void ggml_vk_soft_max(
|
||||||
kp::shader_data::op_softmax_comp_spv_len);
|
kp::shader_data::op_softmax_comp_spv_len);
|
||||||
|
|
||||||
struct PushConstants {
|
struct PushConstants {
|
||||||
uint32_t inOff, outOff;
|
uint32_t inAOff, inBOff, outOff;
|
||||||
int32_t ne00, ne01, ne02;
|
int32_t ne00, ne01, ne02;
|
||||||
float scale;
|
float scale;
|
||||||
|
int32_t mask;
|
||||||
} pushConsts {
|
} pushConsts {
|
||||||
safe_divide(inOff, 4), safe_divide(outOff, 4),
|
safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
||||||
ne00, ne01, ne02, scale
|
ne00, ne01, ne02,
|
||||||
|
scale,
|
||||||
|
bool(inB)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto & inB_ = inB ? inB : inA;
|
||||||
|
|
||||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||||
// FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
|
// FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
|
||||||
const uint32_t local_x = 32;
|
const uint32_t local_x = 32;
|
||||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
|
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
|
||||||
} else {
|
} else {
|
||||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||||
s_algo->setTensors({in, out});
|
s_algo->setTensors({inA, inB_, out});
|
||||||
s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
|
s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
|
||||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||||
s_algo->updateDescriptors(s_kompute_context->pool.get());
|
s_algo->updateDescriptors(s_kompute_context->pool.get());
|
||||||
|
@@ -1552,7 +1558,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
const float scale = ((float *) dst->op_params)[0];
|
const float scale = ((float *) dst->op_params)[0];
|
||||||
ggml_vk_soft_max(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, scale);
|
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
|
|
|
@@ -6,16 +6,19 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0) in;
|
layout(local_size_x_id = 0) in;
|
||||||
|
|
||||||
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
|
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
|
||||||
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
|
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
|
||||||
|
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
|
||||||
|
|
||||||
layout(push_constant) uniform PushConstants {
|
layout(push_constant) uniform PushConstants {
|
||||||
uint inOff;
|
uint inAOff;
|
||||||
|
uint inBOff;
|
||||||
uint outOff;
|
uint outOff;
|
||||||
int ne00;
|
int ne00;
|
||||||
int ne01;
|
int ne01;
|
||||||
int ne02;
|
int ne02;
|
||||||
float scale;
|
float scale;
|
||||||
|
int mask;
|
||||||
} pcs;
|
} pcs;
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
@@ -27,20 +30,21 @@ void main() {
|
||||||
const uint i01 = gl_WorkGroupID.x;
|
const uint i01 = gl_WorkGroupID.x;
|
||||||
|
|
||||||
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
|
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
|
||||||
const uint psrc0 = extra_off + pcs.inOff; // Based from in_
|
const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
|
||||||
|
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
|
||||||
const uint pdst = extra_off + pcs.outOff; // Based from out_
|
const uint pdst = extra_off + pcs.outOff; // Based from out_
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float localMax = uintBitsToFloat(0xFF800000);
|
float localMax = uintBitsToFloat(0xFF800000);
|
||||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||||
localMax = max(localMax, in_[psrc0 + i00]*pcs.scale);
|
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
|
||||||
}
|
}
|
||||||
float max_ = subgroupMax(localMax);
|
float max_ = subgroupMax(localMax);
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float localSum = 0.0f;
|
float localSum = 0.0f;
|
||||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||||
const float exp_psrc0 = exp(in_[psrc0 + i00]*pcs.scale - max_);
|
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
|
||||||
localSum += exp_psrc0;
|
localSum += exp_psrc0;
|
||||||
out_[pdst + i00] = exp_psrc0;
|
out_[pdst + i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue