1D Blocktiling
This commit is contained in:
parent
baf9ff536b
commit
1b4863c2b9
2 changed files with 38 additions and 25 deletions
|
@ -1,10 +1,13 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#define BLOCKSIZE 32
|
#define BM 64
|
||||||
|
#define BN 64
|
||||||
|
#define BK 8
|
||||||
|
#define TM 8
|
||||||
|
|
||||||
#extension GL_EXT_control_flow_attributes : enable
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
layout(local_size_x = BLOCKSIZE * BLOCKSIZE, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = (BM * BN) / TM, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||||
layout (binding = 1) readonly buffer B { float data_b[]; };
|
layout (binding = 1) readonly buffer B { float data_b[]; };
|
||||||
|
@ -20,42 +23,50 @@ layout (push_constant) uniform parameter
|
||||||
int stride_d;
|
int stride_d;
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
shared float buf_a[(BLOCKSIZE+1) * BLOCKSIZE];
|
shared float buf_a[BM * (BK+1)];
|
||||||
shared float buf_b[(BLOCKSIZE+1) * BLOCKSIZE];
|
shared float buf_b[BN * (BK+1)];
|
||||||
|
|
||||||
void main()
|
|
||||||
{
|
|
||||||
const int lr = int(gl_LocalInvocationID.x % BLOCKSIZE);
|
|
||||||
const int lc = int(gl_LocalInvocationID.x / BLOCKSIZE);
|
|
||||||
|
|
||||||
|
void main() {
|
||||||
const int ir = int(gl_WorkGroupID.x);
|
const int ir = int(gl_WorkGroupID.x);
|
||||||
const int ic = int(gl_WorkGroupID.y);
|
const int ic = int(gl_WorkGroupID.y);
|
||||||
|
|
||||||
int pos_a = ir * BLOCKSIZE * p.stride_a;
|
const int lr = int(gl_LocalInvocationID.x % BK);
|
||||||
int pos_b = ic * BLOCKSIZE * p.stride_b;
|
const int lc = int(gl_LocalInvocationID.x / BK);
|
||||||
|
|
||||||
float sum = 0.0f;
|
int pos_a = ir * BM * p.stride_a;
|
||||||
|
int pos_b = ic * BN * p.stride_b;
|
||||||
|
|
||||||
[[unroll]] for (int i = 0; i < p.K; i += BLOCKSIZE) {
|
float sums[TM];
|
||||||
buf_a[lc * (BLOCKSIZE+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
|
float btmp;
|
||||||
buf_b[lc * (BLOCKSIZE+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
|
|
||||||
|
[[unroll]] for (int k = 0; k < TM; k++) {
|
||||||
|
sums[k] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < p.K; i += BK) {
|
||||||
|
// A is transposed
|
||||||
|
buf_a[lc * (BK+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
|
||||||
|
buf_b[lc * (BK+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
pos_a += BLOCKSIZE;
|
pos_a += BK;
|
||||||
pos_b += BLOCKSIZE;
|
pos_b += BK;
|
||||||
|
|
||||||
[[unroll]] for (int j = 0; j < BLOCKSIZE; j++) {
|
[[unroll]] for (int j = 0; j < BK; j++) {
|
||||||
sum += buf_a[lr * (BLOCKSIZE+1) + j] * buf_b[lc * (BLOCKSIZE+1) + j];
|
btmp = buf_b[lc * (BK+1) + j];
|
||||||
|
[[unroll]] for (int k = 0; k < TM; k++) {
|
||||||
|
sums[k] += buf_a[(lr * TM + k) * (BK+1) + j] * btmp;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int dr = ir * BLOCKSIZE + lr;
|
const int dr = ir * BM + lr * TM;
|
||||||
const int dc = ic * BLOCKSIZE + lc;
|
const int dc = ic * BN + lc;
|
||||||
|
|
||||||
if (dr < p.M && dc < p.N) {
|
[[unroll]] for (int k = 0; k < TM; k++) {
|
||||||
data_d[dc * p.stride_d + dr] = sum;
|
data_d[dc * p.stride_d + dr + k] = sums[k];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,8 @@ inline static void* ggml_aligned_malloc(size_t size, size_t alignment) {
|
||||||
|
|
||||||
#define VK_API_VERSION VK_API_VERSION_1_2
|
#define VK_API_VERSION VK_API_VERSION_1_2
|
||||||
|
|
||||||
|
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||||||
|
|
||||||
vk::Instance vk_instance;
|
vk::Instance vk_instance;
|
||||||
uint32_t vk_compute_queue_family_index;
|
uint32_t vk_compute_queue_family_index;
|
||||||
vk::PhysicalDevice vk_physical_device;
|
vk::PhysicalDevice vk_physical_device;
|
||||||
|
@ -470,7 +472,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
0,
|
0,
|
||||||
{ descriptor_set },
|
{ descriptor_set },
|
||||||
{});
|
{});
|
||||||
cmd_buffer.dispatch((ne01 + 31) / 32, (ne11 + 31) / 32, 1);
|
cmd_buffer.dispatch(CEIL_DIV(ne01, 64), CEIL_DIV(ne11, 64), 1);
|
||||||
cmd_buffer.end();
|
cmd_buffer.end();
|
||||||
|
|
||||||
vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
|
vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
|
||||||
|
@ -494,7 +496,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
|
ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
|
||||||
|
|
||||||
#ifdef false
|
#if 0
|
||||||
const float * x = (float *) ((char *) src0->data);
|
const float * x = (float *) ((char *) src0->data);
|
||||||
const float * y = (float *) ((char *) src1->data);
|
const float * y = (float *) ((char *) src1->data);
|
||||||
float * d_chk = (float *) malloc(sizeof(float) * d_ne);
|
float * d_chk = (float *) malloc(sizeof(float) * d_ne);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue