Add WIP warp tile mat mul shaders
This commit is contained in:
parent
80b17e2f66
commit
244939029d
3 changed files with 278 additions and 6 deletions
|
@ -4,7 +4,9 @@
|
|||
#include <cblas.h>
|
||||
#include <cmath>
|
||||
#include <chrono>
|
||||
#endif
|
||||
|
||||
#ifdef VK_PROFILE
|
||||
#define PROFILE(name, block) do { \
|
||||
auto begin = std::chrono::high_resolution_clock::now(); \
|
||||
block \
|
||||
|
@ -893,9 +895,6 @@ static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct gg
|
|||
}
|
||||
if (nb0 == ts) {
|
||||
PROFILE("ggml_vk_buffer_write_2d",
|
||||
// for (uint64_t i1 = 0; i1 < ne1; i1++) {
|
||||
// ggml_vk_buffer_write(dst, offset + i1 * row_length, (uint8_t *)x + i1 * nb1, row_length, q);
|
||||
// }
|
||||
ggml_vk_buffer_write_2d(dst, offset, x, nb1, row_length, ne1, q);
|
||||
);
|
||||
return;
|
||||
|
@ -1169,9 +1168,6 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
|||
// VK_CHECK(vkSetKernelArg(*dmmv, 4, sizeof(vk_int), &ncols));
|
||||
// VK_CHECK(vkEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
|
||||
} else { // general dequantization kernel + VK matrix matrix multiplication
|
||||
// copy src1 to device
|
||||
ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1]);
|
||||
|
||||
// convert src0 to fp32 on device
|
||||
// Wait for transfers to finish
|
||||
vk_transfer_queues[0].queue.waitIdle();
|
||||
|
@ -1179,6 +1175,9 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
|||
vk_device.resetFences({ fence });
|
||||
ggml_vk_dispatch_pipeline(*to_fp32_vk, {&d_Q, &d_X}, { (int)x_ne }, { (uint32_t)x_ne, 1, 1}, cmd_buffer, fence);
|
||||
|
||||
// copy src1 to device
|
||||
ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1]);
|
||||
|
||||
// wait for conversion
|
||||
vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_q_f32 src0 convert waitForFences");
|
||||
|
||||
|
|
137
vk_shaders/matmul_f16_warptile.glsl
Normal file
137
vk_shaders/matmul_f16_warptile.glsl
Normal file
|
@ -0,0 +1,137 @@
|
|||
#version 450
|
||||
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 16
|
||||
#define WM 64
|
||||
#define WN 64
|
||||
#define WMITER 4
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
#define WARP 32
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
|
||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float16_t data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float16_t data_b[]; };
|
||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
int stride_d;
|
||||
} p;
|
||||
|
||||
shared float16_t buf_a[BM * (BK+1)];
|
||||
shared float16_t buf_b[BN * (BK+1)];
|
||||
|
||||
void main() {
|
||||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||
const int warp_r = warp_i % (BM / WM);
|
||||
const int warp_c = warp_i / (BM / WM);
|
||||
|
||||
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const int WSUBM = WM / WMITER;
|
||||
const int WSUBN = WN / WNITER;
|
||||
|
||||
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||
const int tiwr = tiw % (WSUBM / TM);
|
||||
const int tiwc = tiw / (WSUBM / TM);
|
||||
|
||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||
|
||||
const int loadstride = int(gl_WorkGroupSize.x);
|
||||
|
||||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
float16_t cache_a[WMITER * TM];
|
||||
float16_t cache_b[WNITER * TN];
|
||||
|
||||
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = 0.0hf;
|
||||
}
|
||||
|
||||
[[unroll]] for (int block = 0; block < p.K; block += BK) {
|
||||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
} else {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
} else {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK;
|
||||
pos_b += BK;
|
||||
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BM + warp_r * WM;
|
||||
const int dc = ic * BN + warp_c * WN;
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
136
vk_shaders/matmul_f32_warptile.glsl
Normal file
136
vk_shaders/matmul_f32_warptile.glsl
Normal file
|
@ -0,0 +1,136 @@
|
|||
#version 450
|
||||
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 16
|
||||
#define WM 64
|
||||
#define WN 64
|
||||
#define WMITER 4
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
#define WARP 32
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float data_b[]; };
|
||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
int stride_d;
|
||||
} p;
|
||||
|
||||
shared float buf_a[BM * (BK+1)];
|
||||
shared float buf_b[BN * (BK+1)];
|
||||
|
||||
void main() {
|
||||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||
const int warp_r = warp_i % (BM / WM);
|
||||
const int warp_c = warp_i / (BM / WM);
|
||||
|
||||
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const int WSUBM = WM / WMITER;
|
||||
const int WSUBN = WN / WNITER;
|
||||
|
||||
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||
const int tiwr = tiw % (WSUBM / TM);
|
||||
const int tiwc = tiw / (WSUBM / TM);
|
||||
|
||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||
|
||||
const int loadstride = int(gl_WorkGroupSize.x);
|
||||
|
||||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
float cache_a[WMITER * TM];
|
||||
float cache_b[WNITER * TN];
|
||||
|
||||
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
[[unroll]] for (int block = 0; block < p.K; block += BK) {
|
||||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
} else {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
} else {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK;
|
||||
pos_b += BK;
|
||||
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BM + warp_r * WM;
|
||||
const int dc = ic * BN + warp_c * WN;
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue