Transfer remaining shaders to header and compile at runtime
parent a47ca7ae7a
commit 592ebb044d

18 changed files with 429 additions and 1470 deletions
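With this change the Vulkan backend stops shipping build-time-compiled .spv files and instead compiles its GLSL compute shaders through libshaderc when the backend initializes. A minimal standalone sketch of that flow, mirroring the ggml_vk_compile_shader added below (diagnostics trimmed; compile_compute is a hypothetical name, not an identifier from this commit):

    #include <shaderc/shaderc.hpp>
    #include <string>
    #include <utility>
    #include <vector>

    // Compile GLSL compute source to SPIR-V at runtime, with preprocessor
    // defines supplied as name/value pairs, as ggml_vk_compile_shader does.
    std::vector<uint32_t> compile_compute(const std::string& name, const std::string& src,
                                          const std::vector<std::pair<std::string, std::string>>& defines) {
        shaderc::Compiler compiler;
        shaderc::CompileOptions options;
        for (const auto& d : defines) {
            options.AddMacroDefinition(d.first, d.second);
        }
        shaderc::SpvCompilationResult module =
            compiler.CompileGlslToSpv(src, shaderc_compute_shader, name.c_str(), options);
        if (module.GetCompilationStatus() != shaderc_compilation_status_success) {
            return {};  // the commit's version prints the preprocessed source and asserts
        }
        return {module.cbegin(), module.cend()};
    }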
CMakeLists.txt
@@ -347,27 +347,12 @@ if (LLAMA_CLBLAST)
 endif()

 if (LLAMA_VULKAN)
-    find_package(Vulkan COMPONENTS glslc)
+    find_package(Vulkan COMPONENTS glslc SPIRV-Tools)
     if (Vulkan_FOUND)
         message(STATUS "Vulkan found")

         add_library(ggml-vulkan STATIC ggml-vulkan.cpp ggml-vulkan.h)
-        target_link_libraries(ggml-vulkan PUBLIC Vulkan::Vulkan)
-
-        set(GGML_VULKAN_SHADERS matmul_f32 matmul_f16 f16_to_f32 dequant_q4_0)
-
-        foreach(s IN LISTS GGML_VULKAN_SHADERS)
-            add_custom_command(
-                OUTPUT "vk_shaders/${s}.spv"
-                COMMAND "${Vulkan_GLSLC_EXECUTABLE}"
-                    -fshader-stage=compute
-                    --target-env=vulkan1.2
-                    "${CMAKE_CURRENT_SOURCE_DIR}/vk_shaders/${s}.glsl"
-                    -o "${CMAKE_CURRENT_BINARY_DIR}/vk_shaders/${s}.spv"
-                DEPENDS "vk_shaders/${s}.glsl"
-            )
-            target_sources(ggml-vulkan PRIVATE "vk_shaders/${s}.spv")
-        endforeach()
+        target_link_libraries(ggml-vulkan PUBLIC Vulkan::Vulkan SPIRV SPIRV-Tools-opt SPIRV-Tools shaderc_combined)

         add_compile_definitions(GGML_USE_VULKAN)
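The new link line pulls in shaderc_combined plus the SPIRV-Tools stack it depends on. Downstream, the SPIR-V words produced at runtime go into an ordinary Vulkan shader module; a hedged sketch of that half using the plain Vulkan C API (make_module is a hypothetical helper, not code from this commit):

    #include <vulkan/vulkan.h>
    #include <cstdint>
    #include <vector>

    // Wrap runtime-compiled SPIR-V words in a VkShaderModule.
    // codeSize is in bytes; pCode points at 32-bit words.
    VkShaderModule make_module(VkDevice device, const std::vector<uint32_t>& spv) {
        VkShaderModuleCreateInfo info = {};
        info.sType    = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
        info.codeSize = spv.size() * sizeof(uint32_t);
        info.pCode    = spv.data();
        VkShaderModule module = VK_NULL_HANDLE;
        if (vkCreateShaderModule(device, &info, nullptr, &module) != VK_SUCCESS) {
            return VK_NULL_HANDLE;
        }
        return module;
    }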
Makefile (15 changes)
@@ -232,21 +232,6 @@ ifdef LLAMA_VULKAN
 OBJS += ggml-vulkan.o
 ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32.glsl -o vk_shaders/matmul_f32.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32_aligned.glsl -o vk_shaders/matmul_f32_aligned.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16.glsl -o vk_shaders/matmul_f16.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_aligned.glsl -o vk_shaders/matmul_f16_aligned.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32.glsl -o vk_shaders/matmul_f16_f32.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32_aligned.glsl -o vk_shaders/matmul_f16_f32_aligned.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16.glsl -o vk_shaders/dequant_mul_mat_vec_f16.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16_f32.glsl -o vk_shaders/dequant_mul_mat_vec_f16_f32.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f32.glsl -o vk_shaders/add_f32.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f16_f32_f16.glsl -o vk_shaders/add_f16_f32_f16.spv & \
-	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/scale_f32.glsl -o vk_shaders/scale_f32.spv & \
-	wait
 endif

 ifneq ($(filter aarch64%,$(UNAME_M)),)
[shader header file, name not captured in this page]
@@ -1,5 +1,238 @@
 #include <string>

+// Generic
+const std::string shader_f32 = R"(
+#define FLOAT_TYPE float
+)";
+const std::string shader_f16 = R"(
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#define FLOAT_TYPE float16_t
+)";
+const std::string shader_int8_ext = R"(
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+)";
+const std::string shader_output_f16 = R"(
+#define OUT_TYPE float16_t
+)";
+const std::string shader_output_f32 = R"(
+#define OUT_TYPE float
+)";
+
+// MULMAT
+
+const std::string mulmat_head = R"(
+#version 450
+
+#define WARP 32
+
+#extension GL_EXT_control_flow_attributes : enable
+)";
+
+const std::string mulmat_body = R"(
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+#ifdef ALIGNED_INPUT
+#define LOAD_VEC 8
+layout (binding = 0) readonly buffer A { A_TYPE data_a[]; };
+layout (binding = 1) readonly buffer B { B_TYPE data_b[]; };
+#else
+#define LOAD_VEC 1
+layout (binding = 0) readonly buffer A { A_TYPE data_a[]; };
+layout (binding = 1) readonly buffer B { B_TYPE data_b[]; };
+#endif
+
+layout (binding = 2) writeonly buffer D { D_TYPE data_d[]; };
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int K;
+    int stride_a;
+    int stride_b;
+    int stride_d;
+    int k_split;
+} p;
+
+layout (constant_id = 1) const int BM = 64;
+layout (constant_id = 2) const int BN = 64;
+layout (constant_id = 3) const int BK = 16;
+layout (constant_id = 4) const int WM = 32;
+layout (constant_id = 5) const int WN = 32;
+layout (constant_id = 6) const int WMITER = 2;
+layout (constant_id = 7) const int TM = 4;
+layout (constant_id = 8) const int TN = 2;
+
+shared FLOAT_TYPE buf_a[BM * (BK+1)];
+shared FLOAT_TYPE buf_b[BN * (BK+1)];
+
+void main() {
+    const int blocks_x = (p.M + BM - 1) / BM;
+    const int ir = int(gl_WorkGroupID.x) % blocks_x;
+    const int ik = int(gl_WorkGroupID.x) / blocks_x;
+    const int ic = int(gl_WorkGroupID.y);
+
+    const int warp_i = int(gl_LocalInvocationID.x / WARP);
+    const int warp_r = warp_i % (BM / WM);
+    const int warp_c = warp_i / (BM / WM);
+
+    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
+    const int WSUBM = WM / WMITER;
+    const int WSUBN = WN / WNITER;
+
+    const int tiw = int(gl_LocalInvocationID.x % WARP);
+    const int tiwr = tiw % (WSUBM / TM);
+    const int tiwc = tiw / (WSUBM / TM);
+
+    const int loadr = int(gl_LocalInvocationID.x % (BK / LOAD_VEC));
+    const int loadc = int(gl_LocalInvocationID.x / (BK / LOAD_VEC));
+
+    const int loadstride = int(gl_WorkGroupSize.x * LOAD_VEC) / BK;
+
+    const int start_k = ik * p.k_split;
+    const int end_k = (ik + 1) * p.k_split;
+
+    int pos_a = ir * BM * p.stride_a / LOAD_VEC + start_k / LOAD_VEC;
+    int pos_b = ic * BN * p.stride_b / LOAD_VEC + start_k / LOAD_VEC;
+
+    D_TYPE sums[WMITER * TM * WNITER * TN];
+    FLOAT_TYPE cache_a[WMITER * TM];
+    FLOAT_TYPE cache_b[WNITER * TN];
+
+    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
+        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
+#ifdef ALIGNED_INPUT
+            A_TYPE tmp = data_a[pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr];
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(tmp[0].x);
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(tmp[0].y);
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(tmp[0].z);
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(tmp[0].w);
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(tmp[1].x);
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(tmp[1].y);
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(tmp[1].z);
+            buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(tmp[1].w);
+#else
+            if (ir * BM + loadc + l < p.M && block + loadr < p.K) {
+                buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
+            } else {
+                buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
+            }
+#endif
+        }
+        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
+#ifdef ALIGNED_INPUT
+            B_TYPE tmp = data_b[pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr];
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(tmp[0].x);
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(tmp[0].y);
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(tmp[0].z);
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(tmp[0].w);
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(tmp[1].x);
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(tmp[1].y);
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(tmp[1].z);
+            buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(tmp[1].w);
+#else
+            if (ic * BN + loadc + l < p.N && block + loadr < p.K) {
+                buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
+            } else {
+                buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
+            }
+#endif
+        }

+        barrier();
+
+        pos_a += BK / LOAD_VEC;
+        pos_b += BK / LOAD_VEC;
+
+        for (int i = 0; i < min(BK, p.K - block); i++) {
+            // Load from shared into cache
+            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (int j = 0; j < TM; j++) {
+                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
+                }
+            }
+            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (int j = 0; j < TN; j++) {
+                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
+                }
+            }
+
+            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += D_TYPE(cache_a[wsir * TM + cr]) * D_TYPE(cache_b[wsic * TN + cc]);
+                        }
+                    }
+                }
+            }
+        }
+
+        barrier();
+    }
+
+    const int dr = ir * BM + warp_r * WM;
+    const int dc = ic * BN + warp_c * WN;
+
+    const int k_split_offset = ik * p.M * p.N;
+
+    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+
+            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
+            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
+            [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
+                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
+                    }
+                }
+            }
+        }
+    }
+}
+)";
+
+const std::string mulmat_split_k_reduce_src = R"(
+#version 450
+
+layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
+
+layout (binding = 0) buffer A { float data[]; };
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int k_num;
+} p;
+
+void main() {
+    const int glr = int(gl_GlobalInvocationID.x);
+    const int glc = int(gl_GlobalInvocationID.y);
+
+    if (glr >= p.M || glc >= p.N) {
+        return;
+    }
+
+    const int idx = glc * p.M + glr;
+
+    float result = 0.0f;
+
+    for (int i = 0; i < p.k_num; i++) {
+        result += data[i * p.M * p.N + idx];
+    }
+
+    data[idx] = result;
+}
+)";
+
 // DEQUANT SHADER
 const std::string dequant_head = R"(
 #version 450
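The matmul template above splits the K dimension across gl_WorkGroupID.x slices, each writing an M x N partial result at offset ik * p.M * p.N, and mulmat_split_k_reduce_src folds the slices back into slice 0. A scalar C++ reference of that contract, assuming the same column-major layout as the shader:

    #include <vector>

    // CPU reference for the split-k reduce pass: sum k_num partial
    // M*N result slices, stored back to back, into the first slice.
    void split_k_reduce(std::vector<float>& data, int M, int N, int k_num) {
        for (int idx = 0; idx < M * N; idx++) {    // idx = glc * M + glr
            float result = 0.0f;
            for (int i = 0; i < k_num; i++) {
                result += data[i * M * N + idx];
            }
            data[idx] = result;
        }
    }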
@@ -9,18 +242,6 @@ const std::string dequant_head = R"(
 #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
 )";

-const std::string dequant_glsl_fp16_ext = R"(
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-)";
-
-const std::string dequant_output_fp16 = R"(
-#define OUT_TYPE float16_t
-)";
-
-const std::string dequant_output_fp32 = R"(
-#define OUT_TYPE float
-)";
-
 const std::string dequant_q4_0_defines = R"(
 #define QUANT_K 32
 #define QUANT_R 2
@@ -83,9 +304,23 @@ const std::string mul_mat_vec_head = R"(
 #extension GL_EXT_shader_8bit_storage : require
 )";

-const std::string mul_mat_vec_fp16 = R"(
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+const std::string mul_mat_vec_b_type_f32 = R"(
+#define B_TYPE float
 )";

+const std::string mul_mat_vec_b_type_f16 = R"(
+#define B_TYPE float16_t
+)";
+
+const std::string mul_mat_vec_f16_defines = R"(
+#define QUANT_K 32
+#define QUANT_R 2
+#define BLOCK_SIZE 32
+
+#define A_TYPE float16_t
+
+#define DEQUANT_FUNC float16_t v0 = x[ib + 0]; \
+    float16_t v1 = x[ib + 1];
+)";
+
 const std::string mul_mat_vec_q4_0_defines = R"(
@@ -99,8 +334,6 @@ struct block_q4_0
     uint8_t qs[16];
 };
 #define A_TYPE block_q4_0
-#define B_TYPE float16_t
-#define OUT_TYPE float

 #define DEQUANT_FUNC const float16_t d = x[ib].d; \
     const uint8_t vui = x[ib].qs[iqs]; \
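For orientation, the q4_0 DEQUANT_FUNC above unpacks blocks in which each byte carries two 4-bit quants, recentered by 8 and scaled by the block's fp16 scale d. A scalar C++ reference (float stands in for float16_t; block_q4_0_ref is a hypothetical mirror of the shader struct, not a ggml type):

    #include <cstdint>

    struct block_q4_0_ref {
        float   d;       // per-block scale (fp16 in the shader)
        uint8_t qs[16];  // 32 packed 4-bit quants
    };

    // Recover the two values that share byte iqs of a q4_0 block.
    void dequant_q4_0(const block_q4_0_ref& x, int iqs, float& v0, float& v1) {
        const uint8_t vui = x.qs[iqs];
        v0 = (int8_t(vui & 0xF) - 8) * x.d;  // low nibble
        v1 = (int8_t(vui >> 4) - 8) * x.d;   // high nibble
    }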
@@ -122,7 +355,7 @@ layout (push_constant) uniform parameter
     int ncols;
 } p;

-shared float16_t tmp[BLOCK_SIZE];
+shared FLOAT_TYPE tmp[BLOCK_SIZE];

 void main() {
     const int block_size = int(gl_WorkGroupSize.x);
@@ -142,20 +375,20 @@ void main() {
         DEQUANT_FUNC

         // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+        tmp[tid] += FLOAT_TYPE(v0 * y[iybs + iqs + 0]);
+        tmp[tid] += FLOAT_TYPE(v1 * y[iybs + iqs + y_offset]);
     }

     // sum up partial sums and write back result
     barrier();
-    [[unroll]] for (int s=block_size/2; s>0; s>>=1) {
+    [[unroll]] for (int s = block_size/2; s > 0; s >>= 1) {
         if (tid < s) {
             tmp[tid] += tmp[tid + s];
         }
         barrier();
     }
     if (tid == 0) {
-        dst[row] = float(tmp[0]);
+        dst[row] = OUT_TYPE(tmp[0]);
     }
 }
 )";
@@ -195,9 +428,9 @@ const std::string mul_f32_src = R"(

 layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;

-layout (binding = 0) buffer X { float data_x[]; };
-layout (binding = 1) buffer Y { float data_y[]; };
-layout (binding = 2) buffer D { float data_d[]; };
+layout (binding = 0) buffer X { X_TYPE data_x[]; };
+layout (binding = 1) buffer Y { Y_TYPE data_y[]; };
+layout (binding = 2) buffer D { D_TYPE data_d[]; };

 layout (push_constant) uniform parameter
 {
@@ -220,6 +453,76 @@ void main() {
         return;
     }

-    data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * data_y[p.y_offset + x];
+    data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(data_y[p.y_offset + x]);
 }
 )";
+
+// ADD
+const std::string add_head = R"(
+#version 450
+)";
+
+const std::string add_body = R"(
+layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
+layout (binding = 0) buffer X { X_TYPE data_x[]; };
+layout (binding = 1) buffer Y { Y_TYPE data_y[]; };
+layout (binding = 2) buffer D { D_TYPE data_d[]; };
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int stride_x;
+    int stride_y;
+    int stride_d;
+    int x_offset;
+    int y_offset;
+    int d_offset;
+    float scale;
+} p;
+
+void main() {
+    const int x = int(gl_GlobalInvocationID.x);
+    const int y = int(gl_GlobalInvocationID.y);
+
+    if (x >= p.M || y >= p.N) {
+        return;
+    }
+
+    data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) + D_TYPE(data_y[p.y_offset + x]);
+}
+)";
+
+// SCALE
+const std::string scale_src = R"(
+#version 450
+
+layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
+
+layout (binding = 0) buffer X { X_TYPE data_x[]; };
+layout (binding = 1) buffer D { D_TYPE data_d[]; };
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int stride_x;
+    int stride_y;
+    int stride_d;
+    int x_offset;
+    int y_offset;
+    int d_offset;
+    float scale;
+} p;
+
+void main() {
+    const int x = int(gl_GlobalInvocationID.x);
+    const int y = int(gl_GlobalInvocationID.y);
+
+    if (x >= p.M || y >= p.N) {
+        return;
+    }
+
+    data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(p.scale);
+}
+)";
ggml-vulkan.cpp (153 changes)
@@ -23,6 +23,7 @@
 #include <cmath>
 #include <fstream>
 #include <iostream>
+#include <iomanip>
 #include <limits>
 #include <tuple>
 #include <vector>
@@ -165,18 +166,34 @@ vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0;

 static std::vector<std::tuple<void*, size_t, vk_buffer>> vk_pinned_memory;

-static std::vector<uint32_t> ggml_vk_compile_shader(const std::string& name, const std::string& src) {
+static std::vector<uint32_t> ggml_vk_compile_shader(const std::string& name, const std::string& src, std::vector<std::string>&& defines) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_compile_shader(" << name << ", " << src << ")" << std::endl;
 #endif
+    GGML_ASSERT(defines.size() % 2 == 0);
+
     shaderc::Compiler compiler;
     shaderc::CompileOptions options;

+    for (size_t i = 0; i < defines.size(); i += 2) {
+        options.AddMacroDefinition(defines[i], defines[i + 1]);
+    }
+
     shaderc::SpvCompilationResult module = compiler.CompileGlslToSpv(src, shaderc_compute_shader, name.c_str(), options);

     if (module.GetCompilationStatus() != shaderc_compilation_status_success) {
+        shaderc::PreprocessedSourceCompilationResult prep_res = compiler.PreprocessGlsl(src, shaderc_compute_shader, name.c_str(), options);
+
+        std::string prep_src = std::string{ prep_res.begin(), prep_res.end() };
+
+        std::stringstream ss(prep_src);
+        std::string line;
+        int counter = 1;
+        while (std::getline(ss, line, '\n')) {
+            std::cout << std::setw(3) << counter++ << std::setw(1) << ": " << line << std::endl;
+        }
         std::cerr << "ggml_vulkan: Shader compile error in " << name << ": " << module.GetErrorMessage();
-        return std::vector<uint32_t>();
+        GGML_ASSERT(false);
     }

     return {module.cbegin(), module.cend()};
@@ -286,12 +303,12 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& name, size_t spv_s
     return pipeline;
 }

-static vk_pipeline ggml_vk_create_pipeline_from_string(const std::string& name, const std::string& src, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<int>&& specialization_constants, uint32_t align) {
+static vk_pipeline ggml_vk_create_pipeline_from_string(const std::string& name, const std::string& src, std::vector<std::string>&& defines, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<int>&& specialization_constants, uint32_t align) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_create_pipeline_from_string(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl;
 #endif

-    const std::vector<uint32_t> spv = ggml_vk_compile_shader(name, src);
+    const std::vector<uint32_t> spv = ggml_vk_compile_shader(name, src, std::move(defines));

     return ggml_vk_create_pipeline(name, spv.size() * sizeof(uint32_t), spv.data(), entrypoint, parameter_count, push_constant_size, wg_denoms, std::move(specialization_constants), align);
 }
@@ -640,34 +657,98 @@ static void ggml_vk_generate_shaders() {
 #endif
     std::cerr << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;

+    // mulmat
+    auto warptile_l = { 128, 128, 128, 16, 64, 64, 2, 4, 4 };
+    auto warptile_m = { 128, 64, 64, 16, 32, 32, 2, 4, 2 };
+    auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 };
+
+    std::stringstream stream;
+    stream << mulmat_head << shader_f32 << mulmat_body;
+    vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_string("matmul_f32_l", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+    vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_string("matmul_f32_m", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+    vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_string("matmul_f32_s", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+    vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+    vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+    vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+
+    stream.str("");
+    stream.clear();
+    stream << mulmat_head << shader_f16 << mulmat_body;
+    vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_string("matmul_f16_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+    vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_string("matmul_f16_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+    vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_string("matmul_f16_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+
+    vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+    vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+    vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+
+    vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+    vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+    vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+    vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+    vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+    vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+
     // Build dequant q4_0
-    std::stringstream stream;
+    stream.str("");
+    stream.clear();

     stream << dequant_head;
     if (vk_device.fp16) {
-        stream << dequant_glsl_fp16_ext;
-    }
-    stream << dequant_q4_0_defines;
-    if (vk_device.fp16) {
-        stream << dequant_output_fp16;
+        stream << shader_f16 << shader_output_f16;
     } else {
-        stream << dequant_output_fp32;
+        stream << shader_output_f32;
     }
-    stream << dequant_body;
+    stream << dequant_q4_0_defines << dequant_body;

-    vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
+    vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), {}, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);

     // mul mat vec
     stream.str("");
     stream.clear();

-    stream << mul_mat_vec_head << mul_mat_vec_fp16 << mul_mat_vec_q4_0_defines << mul_mat_vec_body;
+    stream << mul_mat_vec_head << shader_f16 << shader_int8_ext << shader_output_f32 << mul_mat_vec_b_type_f16 << mul_mat_vec_q4_0_defines << mul_mat_vec_body;

-    vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+    vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+
+    stream.str("");
+    stream.clear();
+
+    stream << mul_mat_vec_head << shader_f16 << shader_int8_ext << shader_output_f32 << mul_mat_vec_b_type_f32 << mul_mat_vec_q4_0_defines << mul_mat_vec_body;
+
+    vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0_f32", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+
+    stream.str("");
+    stream.clear();
+
+    stream << mul_mat_vec_head << shader_f16 << shader_output_f32 << mul_mat_vec_b_type_f16 << mul_mat_vec_f16_defines << mul_mat_vec_body;
+
+    vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+
+    stream.str("");
+    stream.clear();
+
+    stream << mul_mat_vec_head << shader_f16 << shader_output_f32 << mul_mat_vec_b_type_f32 << mul_mat_vec_f16_defines << mul_mat_vec_body;
+    vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16_f32", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+
+    // add
+    stream.str("");
+    stream.clear();
+
+    stream << add_head << add_body;
+    vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_string("add_f32", stream.str(), { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
+    stream.str("");
+    stream.clear();
+
+    stream << add_head << shader_f16 << add_body;
+    vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_string("add_f16_f32_f16", stream.str(), { "X_TYPE", "float16_t", "Y_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);

     // Static shaders
-    vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
-    vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
+    vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_string("split_k_reduce", mulmat_split_k_reduce_src, {}, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
+    vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, {}, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
+    vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
+
+    vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_string("scale_f32", scale_src, { "X_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
 }

 void ggml_vk_test_transfer(size_t ne);
@@ -806,45 +887,7 @@ void ggml_vk_init(void) {

     vk_device.descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;

-    // Prepare matmul values
-    auto warptile_l = { 128, 128, 128, 16, 64, 64, 2, 4, 4 };
-    auto warptile_m = { 128, 64, 64, 16, 32, 32, 2, 4, 2 };
-    auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 };
-
-    // Shaders
-    vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
-    vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
-    vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
-    vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
-    vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
-    vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
-    if (vk_device.fp16) {
-        vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
-        vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
-        vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
-        vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
-        vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
-        vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
-
-        vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
-        vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
-        vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
-        vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
-        vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
-        vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
-
-        vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
-    }
-    vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
-
-    vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
-    vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
-
-    vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
-    vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f16_f32_f16.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
-
-    vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/scale_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
-
     ggml_vk_generate_shaders();

     // Queues
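For reference, the nine warptile values map, in order, onto the shaders' specialization constants: local_size_x (constant_id 0), then BM, BN, BK, WM, WN, WMITER, TM, TN (ids 1 through 8). A hedged sketch of how such a flat int array is typically bound with the Vulkan C API (make_spec_info is a hypothetical helper, not this codebase's ggml_vk_create_pipeline):

    #include <vulkan/vulkan.h>
    #include <cstdint>
    #include <vector>

    // Bind a flat int array to specialization constant IDs 0..N-1,
    // matching the constant_id declarations in the matmul shaders.
    VkSpecializationInfo make_spec_info(const std::vector<int>& values,
                                        std::vector<VkSpecializationMapEntry>& entries) {
        entries.clear();
        for (uint32_t i = 0; i < values.size(); i++) {
            entries.push_back({ i, i * uint32_t(sizeof(int)), sizeof(int) });
        }
        VkSpecializationInfo info = {};
        info.mapEntryCount = uint32_t(entries.size());
        info.pMapEntries   = entries.data();
        info.dataSize      = values.size() * sizeof(int);
        info.pData         = values.data();
        return info;
    }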
vk_shaders/add_f16_f32_f16.glsl (deleted)
@@ -1,33 +0,0 @@
-#version 450
-
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
-
-layout (binding = 0) buffer X { float16_t data_x[]; };
-layout (binding = 1) buffer Y { float data_y[]; };
-layout (binding = 2) buffer D { float16_t data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int stride_x;
-    int stride_y;
-    int stride_d;
-    int x_offset;
-    int y_offset;
-    int d_offset;
-    float scale;
-} p;
-
-void main() {
-    const int x = int(gl_GlobalInvocationID.x);
-    const int y = int(gl_GlobalInvocationID.y);
-
-    if (x >= p.M || y >= p.N) {
-        return;
-    }
-
-    data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] + float16_t(data_y[p.y_offset + x]);
-}
vk_shaders/add_f32.glsl (deleted)
@@ -1,31 +0,0 @@
-#version 450
-
-layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
-
-layout (binding = 0) buffer X { float data_x[]; };
-layout (binding = 1) buffer Y { float data_y[]; };
-layout (binding = 2) buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int stride_x;
-    int stride_y;
-    int stride_d;
-    int x_offset;
-    int y_offset;
-    int d_offset;
-    float scale;
-} p;
-
-void main() {
-    const int x = int(gl_GlobalInvocationID.x);
-    const int y = int(gl_GlobalInvocationID.y);
-
-    if (x >= p.M || y >= p.N) {
-        return;
-    }
-
-    data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] + data_y[p.y_offset + x];
-}
vk_shaders/dequant_mul_mat_vec_f16.glsl (deleted)
@@ -1,59 +0,0 @@
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-#define QUANT_K 32
-#define QUANT_R 2
-#define BLOCK_SIZE 32
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { float16_t x[]; };
-layout (binding = 1) readonly buffer B { float16_t y[]; };
-layout (binding = 2) writeonly buffer D { float dst[]; };
-
-layout (push_constant) uniform parameter
-{
-    int ncols;
-} p;
-
-shared float16_t tmp[BLOCK_SIZE];
-
-void main() {
-    const int block_size = int(gl_WorkGroupSize.x);
-    const int row = int(gl_WorkGroupID.x);
-    const int tid = int(gl_LocalInvocationID.x);
-
-    const int y_offset = QUANT_K/2;
-
-    tmp[tid] = 0.0hf;
-
-    [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*p.ncols + col)/QUANT_K; // block index
-        const int iqs = (col%QUANT_K)/QUANT_R; // quant index
-        const int iybs = col - col%QUANT_K; // y block start index
-
-        // dequantize
-        float16_t v0 = x[ib + 0];
-        float16_t v1 = x[ib + 1];
-
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        dst[row] = float(tmp[0]);
-    }
-}
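The shared-memory tree reduction at the end of these mul_mat_vec shaders carries over unchanged into the new mul_mat_vec_body. A scalar C++ reference of what it computes; in GLSL the barrier() keeps the rounds in lockstep across the workgroup:

    #include <vector>

    // CPU reference for the shader's reduction: halve the active stride
    // each round; afterwards tmp[0] holds the sum of all partials.
    // Assumes a power-of-two size, as BLOCK_SIZE = 32 is.
    float reduce_block(std::vector<float> tmp) {
        for (size_t s = tmp.size() / 2; s > 0; s >>= 1) {
            for (size_t tid = 0; tid < s; tid++) {  // every "thread" of one round
                tmp[tid] += tmp[tid + s];
            }
        }
        return tmp[0];
    }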
vk_shaders/dequant_mul_mat_vec_f16_f32.glsl (deleted)
@@ -1,59 +0,0 @@
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-#define QUANT_K 32
-#define QUANT_R 2
-#define BLOCK_SIZE 32
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { float16_t x[]; };
-layout (binding = 1) readonly buffer B { float y[]; };
-layout (binding = 2) writeonly buffer D { float dst[]; };
-
-layout (push_constant) uniform parameter
-{
-    int ncols;
-} p;
-
-shared float tmp[BLOCK_SIZE];
-
-void main() {
-    const int block_size = int(gl_WorkGroupSize.x);
-    const int row = int(gl_WorkGroupID.x);
-    const int tid = int(gl_LocalInvocationID.x);
-
-    const int y_offset = QUANT_K/2;
-
-    tmp[tid] = 0.0hf;
-
-    [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*p.ncols + col)/QUANT_K; // block index
-        const int iqs = (col%QUANT_K)/QUANT_R; // quant index
-        const int iybs = col - col%QUANT_K; // y block start index
-
-        // dequantize
-        float16_t v0 = x[ib + 0];
-        float16_t v1 = x[ib + 1];
-
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        dst[row] = tmp[0];
-    }
-}
vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl (deleted)
@@ -1,73 +0,0 @@
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-
-#define QUANT_K 32
-#define QUANT_R 2
-#define BLOCK_SIZE 32
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-struct block_q4_0
-{
-    float16_t d;
-    uint8_t qs[16];
-};
-
-layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
-layout (binding = 1) readonly buffer B { float y[]; };
-layout (binding = 2) writeonly buffer D { float dst[]; };
-
-layout (push_constant) uniform parameter
-{
-    int ncols;
-} p;
-
-shared float tmp[BLOCK_SIZE];
-
-void main() {
-    const int block_size = int(gl_WorkGroupSize.x);
-    const int row = int(gl_WorkGroupID.x);
-    const int tid = int(gl_LocalInvocationID.x);
-
-    const int y_offset = QUANT_K/2;
-
-    tmp[tid] = 0.0hf;
-
-    [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*p.ncols + col)/QUANT_K; // block index
-        const int iqs = (col%QUANT_K)/QUANT_R; // quant index
-        const int iybs = col - col%QUANT_K; // y block start index
-
-        // dequantize
-        const float16_t d = x[ib].d;
-
-        const uint8_t vui = x[ib].qs[iqs];
-
-        const int8_t vi0 = int8_t(vui & 0xF);
-        const int8_t vi1 = int8_t(vui >> 4);
-
-        float16_t v0 = float16_t(vi0 - 8)*d;
-        float16_t v1 = float16_t(vi1 - 8)*d;
-
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        dst[row] = tmp[0];
-    }
-}
vk_shaders/matmul_f16.glsl (deleted)
@@ -1,145 +0,0 @@
-#version 450
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { float16_t data_a[]; };
-layout (binding = 1) readonly buffer B { float16_t data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-    int k_split;
-} p;
-
-layout (constant_id = 1) const int BM = 64;
-layout (constant_id = 2) const int BN = 64;
-layout (constant_id = 3) const int BK = 16;
-layout (constant_id = 4) const int WM = 32;
-layout (constant_id = 5) const int WN = 32;
-layout (constant_id = 6) const int WMITER = 2;
-layout (constant_id = 7) const int TM = 4;
-layout (constant_id = 8) const int TN = 2;
-
-shared float16_t buf_a[BM * (BK+1)];
-shared float16_t buf_b[BN * (BK+1)];
-
-void main() {
-    const int blocks_x = (p.M + BM - 1) / BM;
-    const int ir = int(gl_WorkGroupID.x) % blocks_x;
-    const int ik = int(gl_WorkGroupID.x) / blocks_x;
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % BK);
-    const int loadc = int(gl_LocalInvocationID.x / BK);
-
-    const int loadstride = int(gl_WorkGroupSize.x);
-
-    const int start_k = ik * p.k_split;
-    const int end_k = (ik + 1) * p.k_split;
-
-    int pos_a = ir * BM * p.stride_a + start_k;
-    int pos_b = ic * BN * p.stride_b + start_k;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float16_t cache_a[WMITER * TM];
-    float16_t cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
-            } else {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
-            } else {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-
-        barrier();
-
-        pos_a += BK;
-        pos_b += BK;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
vk_shaders/matmul_f16_aligned.glsl (deleted)
@@ -1,149 +0,0 @@
-#version 450
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { f16mat2x4 data_a[]; };
-layout (binding = 1) readonly buffer B { f16mat2x4 data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-    int k_split;
-} p;
-
-layout (constant_id = 1) const int BM = 64;
-layout (constant_id = 2) const int BN = 64;
-layout (constant_id = 3) const int BK = 16;
-layout (constant_id = 4) const int WM = 32;
-layout (constant_id = 5) const int WN = 32;
-layout (constant_id = 6) const int WMITER = 2;
-layout (constant_id = 7) const int TM = 4;
-layout (constant_id = 8) const int TN = 2;
-
-shared float16_t buf_a[BM * (BK+1)];
-shared float16_t buf_b[BN * (BK+1)];
-
-void main() {
-    const int blocks_x = (p.M + BM - 1) / BM;
-    const int ir = int(gl_WorkGroupID.x) % blocks_x;
-    const int ik = int(gl_WorkGroupID.x) / blocks_x;
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % (BK / 8));
-    const int loadc = int(gl_LocalInvocationID.x / (BK / 8));
-
-    const int loadstride = int(gl_WorkGroupSize.x * 8) / BK;
-
-    const int start_k = ik * p.k_split;
-    const int end_k = (ik + 1) * p.k_split;
-
-    int pos_a = ir * BM * p.stride_a / 8 + start_k / 8;
-    int pos_b = ic * BN * p.stride_b / 8 + start_k / 8;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float16_t cache_a[WMITER * TM];
-    float16_t cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
-            f16mat2x4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 8 + loadr];
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
-        }
-        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
-            f16mat2x4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 8 + loadr];
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
-        }
-
-        barrier();
-
-        pos_a += BK / 8;
-        pos_b += BK / 8;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
vk_shaders/matmul_f16_f32.glsl (deleted)
@@ -1,145 +0,0 @@
-#version 450
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { float16_t data_a[]; };
-layout (binding = 1) readonly buffer B { float data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-    int k_split;
-} p;
-
-layout (constant_id = 1) const int BM = 64;
-layout (constant_id = 2) const int BN = 64;
-layout (constant_id = 3) const int BK = 16;
-layout (constant_id = 4) const int WM = 32;
-layout (constant_id = 5) const int WN = 32;
-layout (constant_id = 6) const int WMITER = 2;
-layout (constant_id = 7) const int TM = 4;
-layout (constant_id = 8) const int TN = 2;
-
-shared float16_t buf_a[BM * (BK+1)];
-shared float16_t buf_b[BN * (BK+1)];
-
-void main() {
-    const int blocks_x = (p.M + BM - 1) / BM;
-    const int ir = int(gl_WorkGroupID.x) % blocks_x;
-    const int ik = int(gl_WorkGroupID.x) / blocks_x;
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % BK);
-    const int loadc = int(gl_LocalInvocationID.x / BK);
-
-    const int loadstride = int(gl_WorkGroupSize.x);
-
-    const int start_k = ik * p.k_split;
-    const int end_k = (ik + 1) * p.k_split;
-
-    int pos_a = ir * BM * p.stride_a + start_k;
-    int pos_b = ic * BN * p.stride_b + start_k;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float16_t cache_a[WMITER * TM];
-    float16_t cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
-            } else {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = float16_t(data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr]);
-            } else {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-
-        barrier();
-
-        pos_a += BK;
-        pos_b += BK;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
@@ -1,149 +0,0 @@
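// Aligned variant of the fp16 x fp32 matmul: A and B are fetched eight elements at a time
// as f16mat2x4/mat2x4, which assumes row strides and k offsets divisible by 8; load bounds checks are omitted.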
#version 450

#define WARP 32

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { f16mat2x4 data_a[]; };
layout (binding = 1) readonly buffer B { mat2x4 data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
    int k_split;
} p;

layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;

shared float16_t buf_a[BM * (BK+1)];
shared float16_t buf_b[BN * (BK+1)];

void main() {
    const int blocks_x = (p.M + BM - 1) / BM;
    const int ir = int(gl_WorkGroupID.x) % blocks_x;
    const int ik = int(gl_WorkGroupID.x) / blocks_x;
    const int ic = int(gl_WorkGroupID.y);

    const int warp_i = int(gl_LocalInvocationID.x / WARP);
    const int warp_r = warp_i % (BM / WM);
    const int warp_c = warp_i / (BM / WM);

    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const int WSUBM = WM / WMITER;
    const int WSUBN = WN / WNITER;

    const int tiw = int(gl_LocalInvocationID.x % WARP);
    const int tiwr = tiw % (WSUBM / TM);
    const int tiwc = tiw / (WSUBM / TM);

    const int loadr = int(gl_LocalInvocationID.x % (BK / 8));
    const int loadc = int(gl_LocalInvocationID.x / (BK / 8));

    const int loadstride = int(gl_WorkGroupSize.x * 8) / BK;

    const int start_k = ik * p.k_split;
    const int end_k = (ik + 1) * p.k_split;

    int pos_a = ir * BM * p.stride_a / 8 + start_k / 8;
    int pos_b = ic * BN * p.stride_b / 8 + start_k / 8;

    float sums[WMITER * TM * WNITER * TN];
    float16_t cache_a[WMITER * TM];
    float16_t cache_b[WNITER * TN];

    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = 0.0f;
    }

    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
            f16mat2x4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 8 + loadr];
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
        }
        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
            mat2x4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 8 + loadr];
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 0] = float16_t(tmp[0].x);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 1] = float16_t(tmp[0].y);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 2] = float16_t(tmp[0].z);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 3] = float16_t(tmp[0].w);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 4] = float16_t(tmp[1].x);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 5] = float16_t(tmp[1].y);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 6] = float16_t(tmp[1].z);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 7] = float16_t(tmp[1].w);
        }

        barrier();

        pos_a += BK / 8;
        pos_b += BK / 8;

        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
                }
            }
            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int j = 0; j < TN; j++) {
                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
                }
            }

            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
                        }
                    }
                }
            }
        }

        barrier();
    }

    const int dr = ir * BM + warp_r * WM;
    const int dc = ic * BN + warp_c * WN;

    const int k_split_offset = ik * p.M * p.N;

    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {

            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (int cc = 0; cc < TN; cc++) {
                [[unroll]] for (int cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                    }
                }
            }
        }
    }
}
@@ -1,144 +0,0 @@
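// Tiled split-k matmul with fp32 inputs and output, using scalar bounds-checked loads.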
#version 450

#define WARP 32

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { float data_a[]; };
layout (binding = 1) readonly buffer B { float data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
    int k_split;
} p;

layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;

shared float buf_a[BM * (BK+1)];
shared float buf_b[BN * (BK+1)];

void main() {
    const int blocks_x = (p.M + BM - 1) / BM;
    const int ir = int(gl_WorkGroupID.x) % blocks_x;
    const int ik = int(gl_WorkGroupID.x) / blocks_x;
    const int ic = int(gl_WorkGroupID.y);

    const int warp_i = int(gl_LocalInvocationID.x / WARP);
    const int warp_r = warp_i % (BM / WM);
    const int warp_c = warp_i / (BM / WM);

    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const int WSUBM = WM / WMITER;
    const int WSUBN = WN / WNITER;

    const int tiw = int(gl_LocalInvocationID.x % WARP);
    const int tiwr = tiw % (WSUBM / TM);
    const int tiwc = tiw / (WSUBM / TM);

    const int loadr = int(gl_LocalInvocationID.x % BK);
    const int loadc = int(gl_LocalInvocationID.x / BK);

    const int loadstride = int(gl_WorkGroupSize.x);

    const int start_k = ik * p.k_split;
    const int end_k = (ik + 1) * p.k_split;

    int pos_a = ir * BM * p.stride_a + start_k;
    int pos_b = ic * BN * p.stride_b + start_k;

    float sums[WMITER * TM * WNITER * TN];
    float cache_a[WMITER * TM];
    float cache_b[WNITER * TN];

    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = 0.0f;
    }

    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
            const int lr = l % BK;
            const int lc = l / BK;
            if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
            } else {
                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
            }
        }
        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
            const int lr = l % BK;
            const int lc = l / BK;
            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
            } else {
                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
            }
        }

        barrier();

        pos_a += BK;
        pos_b += BK;

        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
                }
            }
            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int j = 0; j < TN; j++) {
                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
                }
            }

            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
                        }
                    }
                }
            }
        }

        barrier();
    }

    const int dr = ir * BM + warp_r * WM;
    const int dc = ic * BN + warp_c * WN;

    const int k_split_offset = ik * p.M * p.N;

    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {

            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (int cc = 0; cc < TN; cc++) {
                [[unroll]] for (int cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                    }
                }
            }
        }
    }
}
@@ -1,140 +0,0 @@
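// Aligned fp32 variant: vec4 loads, assuming row strides and k offsets divisible by 4; load bounds checks are omitted.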
#version 450

#define WARP 32

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { vec4 data_a[]; };
layout (binding = 1) readonly buffer B { vec4 data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
    int k_split;
} p;

layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;

shared float buf_a[BM * (BK+1)];
shared float buf_b[BN * (BK+1)];

void main() {
    const int blocks_x = (p.M + BM - 1) / BM;
    const int ir = int(gl_WorkGroupID.x) % blocks_x;
    const int ik = int(gl_WorkGroupID.x) / blocks_x;
    const int ic = int(gl_WorkGroupID.y);

    const int warp_i = int(gl_LocalInvocationID.x / WARP);
    const int warp_r = warp_i % (BM / WM);
    const int warp_c = warp_i / (BM / WM);

    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const int WSUBM = WM / WMITER;
    const int WSUBN = WN / WNITER;

    const int tiw = int(gl_LocalInvocationID.x % WARP);
    const int tiwr = tiw % (WSUBM / TM);
    const int tiwc = tiw / (WSUBM / TM);

    const int loadr = int(gl_LocalInvocationID.x % (BK / 4));
    const int loadc = int(gl_LocalInvocationID.x / (BK / 4));

    const int loadstride = int(gl_WorkGroupSize.x * 4) / BK;

    const int start_k = ik * p.k_split;
    const int end_k = (ik + 1) * p.k_split;

    int pos_a = ir * BM * p.stride_a / 4 + start_k / 4;
    int pos_b = ic * BN * p.stride_b / 4 + start_k / 4;

    float sums[WMITER * TM * WNITER * TN];
    float cache_a[WMITER * TM];
    float cache_b[WNITER * TN];

    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = 0.0f;
    }

    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
            vec4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 4 + loadr];
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
        }
        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
            vec4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 4 + loadr];
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
        }

        barrier();

        pos_a += BK / 4;
        pos_b += BK / 4;

        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
                }
            }
            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int j = 0; j < TN; j++) {
                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
                }
            }

            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
                        }
                    }
                }
            }
        }

        barrier();
    }

    const int dr = ir * BM + warp_r * WM;
    const int dc = ic * BN + warp_c * WN;

    const int k_split_offset = ik * p.M * p.N;

    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {

            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (int cc = 0; cc < TN; cc++) {
                [[unroll]] for (int cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                    }
                }
            }
        }
    }
}
@@ -1,169 +0,0 @@
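// Tiled matmul over q4_0-quantized A: each block of QUANT_K weights stores 4-bit values plus a shared
// fp16 scale d; nibbles are offset by -8 and scaled while being dequantized into shared memory, then multiplied in fp32.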
#version 450

#define WARP 32

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

#define QUANT_K 32
#define QUANT_R 2

struct block_q4_0
{
    float16_t d;
    uint8_t qs[16];
};

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { block_q4_0 data_a[]; };
layout (binding = 1) readonly buffer B { vec4 data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
    int k_split;
} p;

layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;

shared float buf_a[BM * (BK+1)];
shared float buf_b[BN * (BK+1)];

void main() {
    const int blocks_x = (p.M + BM - 1) / BM;
    const int ir = int(gl_WorkGroupID.x) % blocks_x;
    const int ik = int(gl_WorkGroupID.x) / blocks_x;
    const int ic = int(gl_WorkGroupID.y);

    const int stride_a = p.stride_a / QUANT_K;

    const int warp_i = int(gl_LocalInvocationID.x / WARP);
    const int warp_r = warp_i % (BM / WM);
    const int warp_c = warp_i / (BM / WM);

    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const int WSUBM = WM / WMITER;
    const int WSUBN = WN / WNITER;

    const int tiw = int(gl_LocalInvocationID.x % WARP);
    const int tiwr = tiw % (WSUBM / TM);
    const int tiwc = tiw / (WSUBM / TM);

    const int loadr = int(gl_LocalInvocationID.x % (BK / 4));
    const int loadc = int(gl_LocalInvocationID.x / (BK / 4));

    const int loadstride = int(gl_WorkGroupSize.x * 4) / BK;

    const int start_k = ik * p.k_split;
    const int end_k = (ik + 1) * p.k_split;

    int pos_b = ic * BN * p.stride_b / 4 + start_k / 4;

    float sums[WMITER * TM * WNITER * TN];
    float cache_a[WMITER * TM];
    float cache_b[WNITER * TN];

    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = 0.0f;
    }

    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
            const int row = (block + loadr * 4) / QUANT_K;
            const int qi = (block + loadr * 4) % QUANT_K;
            const block_q4_0 blk = data_a[(ir * BM + loadc + l) * stride_a + row];
            const float d = float(blk.d);

            int x0, x1, x2, x3;
            if (qi < 16) {
                x0 = (blk.qs[qi + 0] & 0x0F) - 8;
                x1 = (blk.qs[qi + 1] & 0x0F) - 8;
                x2 = (blk.qs[qi + 2] & 0x0F) - 8;
                x3 = (blk.qs[qi + 3] & 0x0F) - 8;
            } else {
                x0 = (blk.qs[qi + 0] >> 4) - 8;
                x1 = (blk.qs[qi + 1] >> 4) - 8;
                x2 = (blk.qs[qi + 2] >> 4) - 8;
                x3 = (blk.qs[qi + 3] >> 4) - 8;
            }

            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 0] = x0*d;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 1] = x1*d;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 2] = x2*d;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 3] = x3*d;
        }
        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
            vec4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 4 + loadr];
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
        }

        barrier();

        pos_b += BK / 4;

        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
                }
            }
            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int j = 0; j < TN; j++) {
                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
                }
            }

            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
                        }
                    }
                }
            }
        }

        barrier();
    }

    const int dr = ir * BM + warp_r * WM;
    const int dc = ic * BN + warp_c * WN;

    const int k_split_offset = ik * p.M * p.N;

    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {

            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (int cc = 0; cc < TN; cc++) {
                [[unroll]] for (int cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                    }
                }
            }
        }
    }
}
@@ -1,31 +0,0 @@
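// Split-k reduction: sums the k_num partial M x N results, laid out at consecutive M*N offsets, back into the first slice.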
#version 450

layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;

layout (binding = 0) buffer A { float data[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int k_num;
} p;

void main() {
    const int glr = int(gl_GlobalInvocationID.x);
    const int glc = int(gl_GlobalInvocationID.y);

    if (glr >= p.M || glc >= p.N) {
        return;
    }

    const int idx = glc * p.M + glr;

    float result = 0.0f;

    for (int i = 0; i < p.k_num; i++) {
        result += data[i * p.M * p.N + idx];
    }

    data[idx] = result;
}
@@ -1,30 +0,0 @@
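// Elementwise scale: data_d = data_x * scale over an M x N region with independent strides and offsets.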
#version 450

layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;

layout (binding = 0) buffer X { float data_x[]; };
layout (binding = 1) buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int stride_x;
    int stride_y;
    int stride_d;
    int x_offset;
    int y_offset;
    int d_offset;
    float scale;
} p;

void main() {
    const int x = int(gl_GlobalInvocationID.x);
    const int y = int(gl_GlobalInvocationID.y);

    if (x >= p.M || y >= p.N) {
        return;
    }

    data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * p.scale;
}