Minor MUL_MAT fix and implemented DIAG_MASK_INF

This commit is contained in:
niansa 2023-06-30 12:19:29 +02:00
parent 964fe8c546
commit f093bf2e5e

View file

@ -223,8 +223,8 @@ static const std::string program_source_head = R"(#version 450
#define QR4_0 2
#define QK4_1 32
#define GELU_COEF_A 0.044715;
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876;
#define GELU_COEF_A 0.044715
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
#ifndef QK_K
#define QK_K 256
@ -235,6 +235,12 @@ static const std::string program_source_head = R"(#version 450
#else
#define K_SCALE_SIZE 4
#endif
#define BM 128
#define BN 128
#define BK 8
#define TM 8
#define TN 8
)";
@ -651,13 +657,56 @@ void ggml_vk_soft_max(kp::Sequence& seq,
}
static const std::string program_mul_mat_f16 = R"(
#define BM 128
#define BN 128
#define BK 8
#define TM 8
#define TN 8
)" MULTILINE_QUOTE(
static const std::string program_diag_mask_inf =
MULTILINE_QUOTE(
layout(push_constant) uniform PushConstants {
uint64_t ne00;
uint64_t ne01;
uint inAOff;
uint inBOff;
uint outOff;
} pcs;
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint64_t i02 = uint64_t(gl_GlobalInvocationID.z);
const uint64_t i01 = uint64_t(gl_GlobalInvocationID.y);
const uint64_t i00 = uint64_t(gl_GlobalInvocationID.x);
const int n_past = inB[pcs.inBOff];
if (i00 > n_past + i01) {
out_[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.outOff)] = uintBitsToFloat(0xFF800000);
} else {
out_[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.outOff)] = inA[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.inAOff)];
}
}
);
void ggml_vk_diag_mask_inf(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
int64_t ne00, int64_t ne01, int64_t ne02) {
const static auto spirv = glsl_compile_source(program_source_head+program_diag_mask_inf, __func__);
struct PushConstants {
int64_t ne00, ne01;
uint32_t inAOff, inBOff, outOff;
} pushConsts {
ne00, ne01, inAOff, inBOff, outOff
};
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts}));
}
static const std::string program_mul_mat_f16 =
MULTILINE_QUOTE(
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
@ -800,13 +849,8 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
}
static const std::string program_mul_mat_f32 = R"(
#define BM 128
#define BN 128
#define BK 8
#define TM 8
#define TN 8
)" MULTILINE_QUOTE(
static const std::string program_mul_mat_f32 =
MULTILINE_QUOTE(
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { float inA[]; };
@ -1041,14 +1085,18 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03);
} break;
case GGML_OP_DIAG_MASK_INF:
{
ggml_vk_diag_mask_inf(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02);
} break;
case GGML_OP_MUL_MAT:
{
if (src0->type == GGML_TYPE_F32
&& src1->type == GGML_TYPE_F32) {
ggml_vk_mul_mat_f32(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb2, nb3);
break;
} else if (src0->type == GGML_TYPE_F32
&& src1->type == GGML_TYPE_F16) {
} else if (src0->type == GGML_TYPE_F16
&& src1->type == GGML_TYPE_F32) {
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
}
}