Add bounds checking to matmul kernels, improve implementation, fix command buffers not freed properly
This commit is contained in:
parent
36cd5d85e9
commit
24eeb97d13
3 changed files with 392 additions and 269 deletions
629
ggml-vulkan.cpp
629
ggml-vulkan.cpp
File diff suppressed because it is too large
Load diff
|
@ -57,12 +57,20 @@ void main() {
|
|||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
} else {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
} else {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
@ -94,7 +102,9 @@ void main() {
|
|||
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
|
||||
if (dr + cr*rstride < p.M && dc + cc < p.N) {
|
||||
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,12 +56,20 @@ void main() {
|
|||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
} else {
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||
}
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
} else {
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
@ -93,7 +101,9 @@ void main() {
|
|||
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
|
||||
if (dr + cr*rstride < p.M && dc + cc < p.N) {
|
||||
data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue