Add bounds checking to matmul kernels, improve implementation, fix command buffers not freed properly

This commit is contained in:
0cc4m 2023-07-02 22:11:58 +02:00
parent 36cd5d85e9
commit 24eeb97d13
3 changed files with 392 additions and 269 deletions

File diff suppressed because it is too large Load diff

View file

@@ -57,12 +57,20 @@ void main() {
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) { [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
const int lr = l % BK; const int lr = l % BK;
const int lc = l / BK; const int lc = l / BK;
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr]; buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
} else {
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
}
} }
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) { [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
const int lr = l % BK; const int lr = l % BK;
const int lc = l / BK; const int lc = l / BK;
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr]; buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
} else {
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
}
} }
barrier(); barrier();
@@ -94,7 +102,9 @@ void main() {
[[unroll]] for (int cc = 0; cc < TN; cc++) { [[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) { [[unroll]] for (int cr = 0; cr < TM; cr++) {
if (dr + cr*rstride < p.M && dc + cc < p.N) {
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr]; data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
} }
} }
}
} }

View file

@@ -56,12 +56,20 @@ void main() {
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) { [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
const int lr = l % BK; const int lr = l % BK;
const int lc = l / BK; const int lc = l / BK;
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr]; buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
} else {
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
}
} }
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) { [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
const int lr = l % BK; const int lr = l % BK;
const int lc = l / BK; const int lc = l / BK;
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr]; buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
} else {
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
}
} }
barrier(); barrier();
@@ -93,7 +101,9 @@ void main() {
[[unroll]] for (int cc = 0; cc < TN; cc++) { [[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) { [[unroll]] for (int cr = 0; cr < TM; cr++) {
if (dr + cr*rstride < p.M && dc + cc < p.N) {
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr]; data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
} }
} }
}
} }