Optimize warptile matmul shader, replace blocktile with it
This commit is contained in:
parent
6d5a0ada8c
commit
c3d947510b
5 changed files with 108 additions and 327 deletions
|
@ -439,9 +439,9 @@ void ggml_vk_init(void) {
|
|||
vmaCreateAllocator(&allocator_info, &vk_allocator);
|
||||
|
||||
// Shaders
|
||||
vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {128, 128, 1});
|
||||
vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {64, 64, 1});
|
||||
if (vk_fp16_support) {
|
||||
vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {128, 128, 1});
|
||||
vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {64, 64, 1});
|
||||
}
|
||||
|
||||
vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1});
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
#version 450
|
||||
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 8
|
||||
#define TM 8
|
||||
#define TN 8
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 16
|
||||
#define WM 32
|
||||
#define WN 32
|
||||
#define WMITER 2
|
||||
#define TM 4
|
||||
#define TN 2
|
||||
|
||||
#define WARP 32
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
|
||||
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float16_t data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float16_t data_b[]; };
|
||||
|
@ -32,10 +37,17 @@ void main() {
|
|||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int rstride = BM / TM;
|
||||
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||
const int warp_r = warp_i % (BM / WM);
|
||||
const int warp_c = warp_i / (BM / WM);
|
||||
|
||||
const int lr = int(gl_LocalInvocationID.x % rstride);
|
||||
const int lc = int(gl_LocalInvocationID.x / rstride);
|
||||
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const int WSUBM = WM / WMITER;
|
||||
const int WSUBN = WN / WNITER;
|
||||
|
||||
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||
const int tiwr = tiw % (WSUBM / TM);
|
||||
const int tiwc = tiw / (WSUBM / TM);
|
||||
|
||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||
|
@ -45,11 +57,11 @@ void main() {
|
|||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float sums[TM * TN];
|
||||
float16_t cache_a[TM];
|
||||
float16_t cache_b[TN];
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
float16_t cache_a[WMITER * TM];
|
||||
float16_t cache_b[WNITER * TN];
|
||||
|
||||
[[unroll]] for (int i = 0; i < TM*TN; i++) {
|
||||
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = 0.0hf;
|
||||
}
|
||||
|
||||
|
@ -80,16 +92,24 @@ void main() {
|
|||
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int j = 0; j < BM; j++) {
|
||||
cache_a[j] = buf_a[(lr + j*rstride) * (BK+1) + i];
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i];
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[cc * TM + cr] += float(cache_a[cr]) * float(cache_b[cc]);
|
||||
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -97,13 +117,20 @@ void main() {
|
|||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BM + lr;
|
||||
const int dc = ic * BN + lc * TN;
|
||||
const int dr = ir * BM + warp_r * WM;
|
||||
const int dc = ic * BN + warp_c * WN;
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr*rstride < p.M && dc + cc < p.N) {
|
||||
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,137 +0,0 @@
|
|||
#version 450
|
||||
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 16
|
||||
#define WM 64
|
||||
#define WN 64
|
||||
#define WMITER 4
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
#define WARP 32
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
|
||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float16_t data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float16_t data_b[]; };
|
||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
int stride_d;
|
||||
} p;
|
||||
|
||||
shared float16_t buf_a[BM * (BK+1)];
|
||||
shared float16_t buf_b[BN * (BK+1)];
|
||||
|
||||
void main() {
|
||||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||
const int warp_r = warp_i % (BM / WM);
|
||||
const int warp_c = warp_i / (BM / WM);
|
||||
|
||||
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const int WSUBM = WM / WMITER;
|
||||
const int WSUBN = WN / WNITER;
|
||||
|
||||
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||
const int tiwr = tiw % (WSUBM / TM);
|
||||
const int tiwc = tiw / (WSUBM / TM);
|
||||
|
||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||
|
||||
const int loadstride = int(gl_WorkGroupSize.x);
|
||||
|
||||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
float16_t cache_a[WMITER * TM];
|
||||
float16_t cache_b[WNITER * TN];
|
||||
|
||||
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = 0.0hf;
|
||||
}
|
||||
|
||||
[[unroll]] for (int block = 0; block < p.K; block += BK) {
|
||||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
} else {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
} else {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK;
|
||||
pos_b += BK;
|
||||
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BM + warp_r * WM;
|
||||
const int dc = ic * BN + warp_c * WN;
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,14 +1,19 @@
|
|||
#version 450
|
||||
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 8
|
||||
#define TM 8
|
||||
#define TN 8
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 16
|
||||
#define WM 32
|
||||
#define WN 32
|
||||
#define WMITER 2
|
||||
#define TM 4
|
||||
#define TN 2
|
||||
|
||||
#define WARP 32
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float data_b[]; };
|
||||
|
@ -31,10 +36,17 @@ void main() {
|
|||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int rstride = BM / TM;
|
||||
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||
const int warp_r = warp_i % (BM / WM);
|
||||
const int warp_c = warp_i / (BM / WM);
|
||||
|
||||
const int lr = int(gl_LocalInvocationID.x % rstride);
|
||||
const int lc = int(gl_LocalInvocationID.x / rstride);
|
||||
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const int WSUBM = WM / WMITER;
|
||||
const int WSUBN = WN / WNITER;
|
||||
|
||||
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||
const int tiwr = tiw % (WSUBM / TM);
|
||||
const int tiwc = tiw / (WSUBM / TM);
|
||||
|
||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||
|
@ -44,11 +56,11 @@ void main() {
|
|||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
float cache_a[WMITER * TM];
|
||||
float cache_b[WNITER * TN];
|
||||
|
||||
[[unroll]] for (int i = 0; i < TM*TN; i++) {
|
||||
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
|
@ -79,16 +91,24 @@ void main() {
|
|||
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int j = 0; j < BM; j++) {
|
||||
cache_a[j] = buf_a[(lr + j*rstride) * (BK+1) + i];
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i];
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[cc * TM + cr] += cache_a[cr] * cache_b[cc];
|
||||
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -96,13 +116,20 @@ void main() {
|
|||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BM + lr;
|
||||
const int dc = ic * BN + lc * TN;
|
||||
const int dr = ir * BM + warp_r * WM;
|
||||
const int dc = ic * BN + warp_c * WN;
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr*rstride < p.M && dc + cc < p.N) {
|
||||
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,136 +0,0 @@
|
|||
#version 450
|
||||
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 16
|
||||
#define WM 64
|
||||
#define WN 64
|
||||
#define WMITER 4
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
#define WARP 32
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float data_b[]; };
|
||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
int stride_d;
|
||||
} p;
|
||||
|
||||
shared float buf_a[BM * (BK+1)];
|
||||
shared float buf_b[BN * (BK+1)];
|
||||
|
||||
void main() {
|
||||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||
const int warp_r = warp_i % (BM / WM);
|
||||
const int warp_c = warp_i / (BM / WM);
|
||||
|
||||
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const int WSUBM = WM / WMITER;
|
||||
const int WSUBN = WN / WNITER;
|
||||
|
||||
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||
const int tiwr = tiw % (WSUBM / TM);
|
||||
const int tiwc = tiw / (WSUBM / TM);
|
||||
|
||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||
|
||||
const int loadstride = int(gl_WorkGroupSize.x);
|
||||
|
||||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
float cache_a[WMITER * TM];
|
||||
float cache_b[WNITER * TN];
|
||||
|
||||
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
[[unroll]] for (int block = 0; block < p.K; block += BK) {
|
||||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
} else {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
} else {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK;
|
||||
pos_b += BK;
|
||||
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BM + warp_r * WM;
|
||||
const int dc = ic * BN + warp_c * WN;
|
||||
|
||||
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue