2D Blocktiling
This commit is contained in:
parent
1b4863c2b9
commit
7c6860b483
2 changed files with 46 additions and 21 deletions
|
@ -1,13 +1,14 @@
|
|||
#version 450
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 8
|
||||
#define TM 8
|
||||
#define TN 8
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = (BM * BN) / TM, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float data_b[]; };
|
||||
|
@ -30,33 +31,55 @@ void main() {
|
|||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int lr = int(gl_LocalInvocationID.x % BK);
|
||||
const int lc = int(gl_LocalInvocationID.x / BK);
|
||||
const int lr = int(gl_LocalInvocationID.x % (BM/TM));
|
||||
const int lc = int(gl_LocalInvocationID.x / (BM/TM));
|
||||
|
||||
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||
|
||||
const int loadstride = int(gl_WorkGroupSize.x);
|
||||
|
||||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float sums[TM];
|
||||
float btmp;
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
[[unroll]] for (int k = 0; k < TM; k++) {
|
||||
sums[k] = 0.0f;
|
||||
[[unroll]] for (int i = 0; i < TM*TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int i = 0; i < p.K; i += BK) {
|
||||
// A is transposed
|
||||
buf_a[lc * (BK+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
|
||||
buf_b[lc * (BK+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
|
||||
for (int block = 0; block < p.K; block += BK) {
|
||||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK;
|
||||
pos_b += BK;
|
||||
|
||||
[[unroll]] for (int j = 0; j < BK; j++) {
|
||||
btmp = buf_b[lc * (BK+1) + j];
|
||||
[[unroll]] for (int k = 0; k < TM; k++) {
|
||||
sums[k] += buf_a[(lr * TM + k) * (BK+1) + j] * btmp;
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(lr * TM + j) * (BK+1) + i];
|
||||
}
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[cc * TM + cr] += cache_a[cr] * cache_b[cc];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,9 +87,11 @@ void main() {
|
|||
}
|
||||
|
||||
const int dr = ir * BM + lr * TM;
|
||||
const int dc = ic * BN + lc;
|
||||
const int dc = ic * BN + lc * TN;
|
||||
|
||||
[[unroll]] for (int k = 0; k < TM; k++) {
|
||||
data_d[dc * p.stride_d + dr + k] = sums[k];
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
data_d[(dc + cc) * p.stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -472,7 +472,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
0,
|
||||
{ descriptor_set },
|
||||
{});
|
||||
cmd_buffer.dispatch(CEIL_DIV(ne01, 64), CEIL_DIV(ne11, 64), 1);
|
||||
cmd_buffer.dispatch(CEIL_DIV(ne01, 128), CEIL_DIV(ne11, 128), 1);
|
||||
cmd_buffer.end();
|
||||
|
||||
vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue