Write coalescing

This commit is contained in:
0cc4m 2023-06-25 09:54:40 +02:00
parent 7c6860b483
commit 0c9cca00bd

View file

@@ -31,8 +31,10 @@ void main() {
const int ir = int(gl_WorkGroupID.x); const int ir = int(gl_WorkGroupID.x);
const int ic = int(gl_WorkGroupID.y); const int ic = int(gl_WorkGroupID.y);
const int lr = int(gl_LocalInvocationID.x % (BM/TM)); const int rstride = BM / TM;
const int lc = int(gl_LocalInvocationID.x / (BM/TM));
const int lr = int(gl_LocalInvocationID.x % rstride);
const int lc = int(gl_LocalInvocationID.x / rstride);
const int loadr = int(gl_LocalInvocationID.x % BK); const int loadr = int(gl_LocalInvocationID.x % BK);
const int loadc = int(gl_LocalInvocationID.x / BK); const int loadc = int(gl_LocalInvocationID.x / BK);
@@ -50,7 +52,7 @@ void main() {
sums[i] = 0.0f; sums[i] = 0.0f;
} }
for (int block = 0; block < p.K; block += BK) { [[unroll]] for (int block = 0; block < p.K; block += BK) {
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) { [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
const int lr = l % BK; const int lr = l % BK;
const int lc = l / BK; const int lc = l / BK;
@@ -69,8 +71,8 @@ void main() {
[[unroll]] for (int i = 0; i < BK; i++) { [[unroll]] for (int i = 0; i < BK; i++) {
// Load from shared into cache // Load from shared into cache
[[unroll]] for (int j = 0; j < TM; j++) { [[unroll]] for (int j = 0; j < BM; j++) {
cache_a[j] = buf_a[(lr * TM + j) * (BK+1) + i]; cache_a[j] = buf_a[(lr + j*rstride) * (BK+1) + i];
} }
[[unroll]] for (int j = 0; j < TN; j++) { [[unroll]] for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i]; cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i];
@@ -86,12 +88,12 @@ void main() {
barrier(); barrier();
} }
const int dr = ir * BM + lr * TM; const int dr = ir * BM + lr;
const int dc = ic * BN + lc * TN; const int dc = ic * BN + lc * TN;
[[unroll]] for (int cc = 0; cc < TN; cc++) { [[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) { [[unroll]] for (int cr = 0; cr < TM; cr++) {
data_d[(dc + cc) * p.stride_d + dr + cr] = sums[cc * TM + cr]; data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
} }
} }
} }