1D Blocktiling

This commit is contained in:
0cc4m 2023-06-24 08:01:43 +02:00
parent baf9ff536b
commit 1b4863c2b9
2 changed files with 38 additions and 25 deletions

View file

@ -1,10 +1,13 @@
#version 450
#define BLOCKSIZE 32
#define BM 64
#define BN 64
#define BK 8
#define TM 8
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = BLOCKSIZE * BLOCKSIZE, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = (BM * BN) / TM, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { float data_a[]; };
layout (binding = 1) readonly buffer B { float data_b[]; };
@ -20,42 +23,50 @@ layout (push_constant) uniform parameter
int stride_d;
} p;
shared float buf_a[(BLOCKSIZE+1) * BLOCKSIZE];
shared float buf_b[(BLOCKSIZE+1) * BLOCKSIZE];
void main()
{
const int lr = int(gl_LocalInvocationID.x % BLOCKSIZE);
const int lc = int(gl_LocalInvocationID.x / BLOCKSIZE);
shared float buf_a[BM * (BK+1)];
shared float buf_b[BN * (BK+1)];
void main() {
const int ir = int(gl_WorkGroupID.x);
const int ic = int(gl_WorkGroupID.y);
int pos_a = ir * BLOCKSIZE * p.stride_a;
int pos_b = ic * BLOCKSIZE * p.stride_b;
const int lr = int(gl_LocalInvocationID.x % BK);
const int lc = int(gl_LocalInvocationID.x / BK);
float sum = 0.0f;
int pos_a = ir * BM * p.stride_a;
int pos_b = ic * BN * p.stride_b;
[[unroll]] for (int i = 0; i < p.K; i += BLOCKSIZE) {
buf_a[lc * (BLOCKSIZE+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
buf_b[lc * (BLOCKSIZE+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
float sums[TM];
float btmp;
[[unroll]] for (int k = 0; k < TM; k++) {
sums[k] = 0.0f;
}
for (int i = 0; i < p.K; i += BK) {
// A is transposed
buf_a[lc * (BK+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
buf_b[lc * (BK+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
barrier();
pos_a += BLOCKSIZE;
pos_b += BLOCKSIZE;
pos_a += BK;
pos_b += BK;
[[unroll]] for (int j = 0; j < BLOCKSIZE; j++) {
sum += buf_a[lr * (BLOCKSIZE+1) + j] * buf_b[lc * (BLOCKSIZE+1) + j];
[[unroll]] for (int j = 0; j < BK; j++) {
btmp = buf_b[lc * (BK+1) + j];
[[unroll]] for (int k = 0; k < TM; k++) {
sums[k] += buf_a[(lr * TM + k) * (BK+1) + j] * btmp;
}
}
barrier();
}
const int dr = ir * BLOCKSIZE + lr;
const int dc = ic * BLOCKSIZE + lc;
const int dr = ir * BM + lr * TM;
const int dc = ic * BN + lc;
if (dr < p.M && dc < p.N) {
data_d[dc * p.stride_d + dr] = sum;
[[unroll]] for (int k = 0; k < TM; k++) {
data_d[dc * p.stride_d + dr + k] = sums[k];
}
}

View file

@ -38,6 +38,8 @@ inline static void* ggml_aligned_malloc(size_t size, size_t alignment) {
#define VK_API_VERSION VK_API_VERSION_1_2
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
vk::Instance vk_instance;
uint32_t vk_compute_queue_family_index;
vk::PhysicalDevice vk_physical_device;
@ -470,7 +472,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
0,
{ descriptor_set },
{});
cmd_buffer.dispatch((ne01 + 31) / 32, (ne11 + 31) / 32, 1);
cmd_buffer.dispatch(CEIL_DIV(ne01, 64), CEIL_DIV(ne11, 64), 1);
cmd_buffer.end();
vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
@ -494,7 +496,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
#ifdef false
#if 0
const float * x = (float *) ((char *) src0->data);
const float * y = (float *) ((char *) src1->data);
float * d_chk = (float *) malloc(sizeof(float) * d_ne);