Ported mat mul from Metal
This commit is contained in:
parent
2fc8249ba3
commit
6be93e6071
1 changed files with 47 additions and 141 deletions
188
ggml-vulkan.cpp
188
ggml-vulkan.cpp
|
@ -173,7 +173,7 @@ static std::vector<uint32_t> glsl_compile_source(const std::string& source, cons
|
||||||
std::ofstream fileOut("tmp_kp_shader.comp");
|
std::ofstream fileOut("tmp_kp_shader.comp");
|
||||||
fileOut << source;
|
fileOut << source;
|
||||||
fileOut.close();
|
fileOut.close();
|
||||||
if (system(std::string("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv > /dev/null").c_str()))
|
if (system("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv > /dev/null"))
|
||||||
throw std::runtime_error("Error running glslangValidator command");
|
throw std::runtime_error("Error running glslangValidator command");
|
||||||
std::ifstream fileStream("tmp_kp_shader.comp.spv", std::ios::binary);
|
std::ifstream fileStream("tmp_kp_shader.comp.spv", std::ios::binary);
|
||||||
std::vector<char> buffer;
|
std::vector<char> buffer;
|
||||||
|
@ -883,131 +883,59 @@ void ggml_vk_diag_mask_inf(kp::Sequence& seq,
|
||||||
|
|
||||||
static const std::string program_mul_mat_f16 =
|
static const std::string program_mul_mat_f16 =
|
||||||
MULTILINE_QUOTE(
|
MULTILINE_QUOTE(
|
||||||
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 64) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
|
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
|
||||||
layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; };
|
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
|
||||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
int M;
|
int64_t ne00;
|
||||||
int N;
|
int64_t ne01;
|
||||||
int K;
|
uint64_t nb00;
|
||||||
int inAStride;
|
uint64_t nb01;
|
||||||
int inBStride;
|
uint64_t nb02;
|
||||||
int outStride;
|
int64_t ne10;
|
||||||
|
int64_t ne11;
|
||||||
|
uint64_t nb10;
|
||||||
|
uint64_t nb11;
|
||||||
|
uint64_t nb12;
|
||||||
|
int64_t ne0;
|
||||||
|
int64_t ne1;
|
||||||
uint inAOff;
|
uint inAOff;
|
||||||
uint inBOff;
|
uint inBOff;
|
||||||
uint outOff;
|
uint outOff;
|
||||||
} pcs;
|
} pcs;
|
||||||
|
|
||||||
shared float16_t bufA[BM * (BK+1)];
|
shared float sum[gl_WorkGroupSize.x];
|
||||||
shared float16_t bufB[BN * (BK+1)];
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const int ir = int(gl_WorkGroupID.x);
|
const int64_t r0 = gl_WorkGroupID.x; // one workgroup per output element; its 64 invocations share one dot product
|
||||||
const int ic = int(gl_WorkGroupID.y);
|
const int64_t r1 = gl_WorkGroupID.y;
|
||||||
|
const int64_t im = gl_WorkGroupID.z;
|
||||||
|
|
||||||
const int rstride = BM / TM;
|
const uint x = uint((r0*pcs.nb01 + im*pcs.nb02) / 2); // start of the inA row as a float16 element index (byte offset / 2). NOTE(review): pcs.inAOff is not applied - confirm
|
||||||
|
const uint y = uint((r1*pcs.nb11 + im*pcs.nb12) / 4); // start of the inB row as a float element index (byte offset / 4). NOTE(review): pcs.inBOff is not applied - confirm
|
||||||
|
|
||||||
const int lr = int(gl_LocalInvocationID.x % rstride);
|
sum[gl_LocalInvocationID.x] = 0.0f; // each invocation owns one slot of the shared scratch array
|
||||||
const int lc = int(gl_LocalInvocationID.x / rstride);
|
|
||||||
|
|
||||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
for (uint i = gl_LocalInvocationID.x; i < pcs.ne00; i += gl_WorkGroupSize.x) {
|
||||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
sum[gl_LocalInvocationID.x] += float(inA[x+i]) * float(inB[y+i]);
|
||||||
|
|
||||||
const int loadstride = int(gl_WorkGroupSize.x);
|
|
||||||
|
|
||||||
int posA = ir * BM * pcs.inAStride;
|
|
||||||
int posB = ic * BN * pcs.inBStride;
|
|
||||||
|
|
||||||
float sums[TM * TN];
|
|
||||||
float16_t cacheA[TM];
|
|
||||||
float16_t cacheB[TN];
|
|
||||||
|
|
||||||
[[unroll]] for (int i = 0; i < TM*TN; i++) {
|
|
||||||
sums[i] = 0.0hf;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (int block = 0; block < pcs.K; block += BK) {
|
// accumulate the sum from all threads in the threadgroup
|
||||||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
barrier();
|
||||||
const int lr = l % BK;
|
memoryBarrierShared();
|
||||||
const int lc = l / BK;
|
for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
|
||||||
bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr];
|
if (gl_LocalInvocationID.x < i) {
|
||||||
|
sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
|
||||||
}
|
}
|
||||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
|
||||||
const int lr = l % BK;
|
|
||||||
const int lc = l / BK;
|
|
||||||
bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr];
|
|
||||||
}
|
|
||||||
|
|
||||||
barrier();
|
|
||||||
|
|
||||||
posA += BK;
|
|
||||||
posB += BK;
|
|
||||||
|
|
||||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
|
||||||
// Load from shared into cache
|
|
||||||
[[unroll]] for (int j = 0; j < BM; j++) {
|
|
||||||
cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
|
|
||||||
}
|
|
||||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
|
||||||
cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
|
|
||||||
}
|
|
||||||
|
|
||||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
|
||||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
|
||||||
sums[cc * TM + cr] += float(cacheA[cr]) * float(cacheB[cc]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
|
memoryBarrierShared();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int dr = ir * BM + lr;
|
if (gl_LocalInvocationID.x == 0) {
|
||||||
const int dc = ic * BN + lc * TN;
|
out_[uint(im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0)] = sum[0];
|
||||||
|
|
||||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
|
||||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
|
||||||
out_[(dc + cc) * pcs.outStride + dr + cr*rstride] = sums[cc * TM + cr];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
static const std::string program_fast_mul_mat_f16 =
|
|
||||||
MULTILINE_QUOTE(
|
|
||||||
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
|
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
|
|
||||||
layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; };
|
|
||||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
|
||||||
int M;
|
|
||||||
int N;
|
|
||||||
int K;
|
|
||||||
int inAStride;
|
|
||||||
int inBStride;
|
|
||||||
int outStride;
|
|
||||||
uint inAOff;
|
|
||||||
uint inBOff;
|
|
||||||
uint outOff;
|
|
||||||
} pcs;
|
|
||||||
|
|
||||||
void main() {
|
|
||||||
int row = int(gl_GlobalInvocationID.x);
|
|
||||||
int col = int(gl_GlobalInvocationID.y);
|
|
||||||
|
|
||||||
if (row < pcs.M && col < pcs.N) {
|
|
||||||
float sum = 0.0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < pcs.K; i++) {
|
|
||||||
sum += float(inA[row * pcs.inAStride + i]) * float(inB[col * pcs.inBStride + i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
out_[col * pcs.outStride + row] = sum;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -1016,48 +944,26 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
|
||||||
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||||
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||||
int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
|
int64_t ne00, int64_t ne01,
|
||||||
int64_t ne10, int64_t ne11,
|
uint64_t nb00, uint64_t nb01, uint64_t nb02,
|
||||||
int nb10, int nb11, int nb12, int nb13,
|
int64_t ne10, int64_t ne11, int64_t ne12,
|
||||||
int nb2, int nb3) {
|
uint64_t nb10, uint64_t nb11, uint64_t nb12,
|
||||||
const static auto spirv = glsl_compile_source(program_source_head+program_fast_mul_mat_f16, __func__);
|
int64_t ne0, int64_t ne1) {
|
||||||
|
const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
|
||||||
const bool inB_cont_rows = nb10 == sizeof(float);
|
|
||||||
const bool inB_cont_cols = (size_t)nb11 == ne11 * sizeof(float);
|
|
||||||
|
|
||||||
struct PushConstants {
|
struct PushConstants {
|
||||||
int32_t M, N, K, inAStride, inBStride, outStride;
|
int64_t ne00, ne01;
|
||||||
|
uint64_t nb00, nb01, nb02;
|
||||||
|
int64_t ne10, ne11;
|
||||||
|
uint64_t nb10, nb11, nb12;
|
||||||
|
int64_t ne0, ne1;
|
||||||
uint32_t inAOff, inBOff, outOff;
|
uint32_t inAOff, inBOff, outOff;
|
||||||
} pushConsts {
|
} pushConsts {
|
||||||
(int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
|
ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, ne0, ne1,
|
||||||
inAOff, inBOff, outOff
|
inAOff, inBOff, outOff
|
||||||
};
|
};
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne11), unsigned(ne12)}, {}, {pushConsts}));
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
||||||
auto tmp = mgr.tensorT<half>(std::vector<half>(ne10*ne11));
|
|
||||||
|
|
||||||
if (inB_cont_rows) {
|
|
||||||
if (inB_cont_cols) {
|
|
||||||
ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12)/sizeof(float), tmp, 0, ne10*ne11);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
|
||||||
ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11)/sizeof(float), tmp, i01*ne10, ne10);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
|
||||||
for (int64_t i00 = 0; i00 < ne10; i00++) {
|
|
||||||
// Extremely slow because of single shader invocation
|
|
||||||
ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)/sizeof(float), tmp, i01*ne10 + i00, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, tmp, out}, spirv, {uint32_t(ne01/128), uint32_t(ne11/128)}, {}, {pushConsts}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1179,7 +1085,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||||
{
|
{
|
||||||
if (src0->type == GGML_TYPE_F16
|
if (src0->type == GGML_TYPE_F16
|
||||||
&& src1->type == GGML_TYPE_F32) {
|
&& src1->type == GGML_TYPE_F32) {
|
||||||
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
|
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue