1D Blocktiling

This commit is contained in:
0cc4m 2023-06-24 08:01:43 +02:00
parent baf9ff536b
commit 1b4863c2b9
2 changed files with 38 additions and 25 deletions

View file

@ -1,10 +1,13 @@
#version 450 #version 450
#define BLOCKSIZE 32 #define BM 64
#define BN 64
#define BK 8
#define TM 8
#extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = BLOCKSIZE * BLOCKSIZE, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = (BM * BN) / TM, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { float data_a[]; }; layout (binding = 0) readonly buffer A { float data_a[]; };
layout (binding = 1) readonly buffer B { float data_b[]; }; layout (binding = 1) readonly buffer B { float data_b[]; };
@ -20,42 +23,50 @@ layout (push_constant) uniform parameter
int stride_d; int stride_d;
} p; } p;
shared float buf_a[(BLOCKSIZE+1) * BLOCKSIZE]; shared float buf_a[BM * (BK+1)];
shared float buf_b[(BLOCKSIZE+1) * BLOCKSIZE]; shared float buf_b[BN * (BK+1)];
void main()
{
const int lr = int(gl_LocalInvocationID.x % BLOCKSIZE);
const int lc = int(gl_LocalInvocationID.x / BLOCKSIZE);
void main() {
const int ir = int(gl_WorkGroupID.x); const int ir = int(gl_WorkGroupID.x);
const int ic = int(gl_WorkGroupID.y); const int ic = int(gl_WorkGroupID.y);
int pos_a = ir * BLOCKSIZE * p.stride_a; const int lr = int(gl_LocalInvocationID.x % BK);
int pos_b = ic * BLOCKSIZE * p.stride_b; const int lc = int(gl_LocalInvocationID.x / BK);
float sum = 0.0f; int pos_a = ir * BM * p.stride_a;
int pos_b = ic * BN * p.stride_b;
[[unroll]] for (int i = 0; i < p.K; i += BLOCKSIZE) { float sums[TM];
buf_a[lc * (BLOCKSIZE+1) + lr] = data_a[pos_a + lc * p.stride_a + lr]; float btmp;
buf_b[lc * (BLOCKSIZE+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
[[unroll]] for (int k = 0; k < TM; k++) {
sums[k] = 0.0f;
}
for (int i = 0; i < p.K; i += BK) {
// A is transposed
buf_a[lc * (BK+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
buf_b[lc * (BK+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
barrier(); barrier();
pos_a += BLOCKSIZE; pos_a += BK;
pos_b += BLOCKSIZE; pos_b += BK;
[[unroll]] for (int j = 0; j < BLOCKSIZE; j++) { [[unroll]] for (int j = 0; j < BK; j++) {
sum += buf_a[lr * (BLOCKSIZE+1) + j] * buf_b[lc * (BLOCKSIZE+1) + j]; btmp = buf_b[lc * (BK+1) + j];
[[unroll]] for (int k = 0; k < TM; k++) {
sums[k] += buf_a[(lr * TM + k) * (BK+1) + j] * btmp;
}
} }
barrier(); barrier();
} }
const int dr = ir * BLOCKSIZE + lr; const int dr = ir * BM + lr * TM;
const int dc = ic * BLOCKSIZE + lc; const int dc = ic * BN + lc;
if (dr < p.M && dc < p.N) { [[unroll]] for (int k = 0; k < TM; k++) {
data_d[dc * p.stride_d + dr] = sum; data_d[dc * p.stride_d + dr + k] = sums[k];
} }
} }

View file

@ -38,6 +38,8 @@ inline static void* ggml_aligned_malloc(size_t size, size_t alignment) {
#define VK_API_VERSION VK_API_VERSION_1_2 #define VK_API_VERSION VK_API_VERSION_1_2
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
vk::Instance vk_instance; vk::Instance vk_instance;
uint32_t vk_compute_queue_family_index; uint32_t vk_compute_queue_family_index;
vk::PhysicalDevice vk_physical_device; vk::PhysicalDevice vk_physical_device;
@ -470,7 +472,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
0, 0,
{ descriptor_set }, { descriptor_set },
{}); {});
cmd_buffer.dispatch((ne01 + 31) / 32, (ne11 + 31) / 32, 1); cmd_buffer.dispatch(CEIL_DIV(ne01, 64), CEIL_DIV(ne11, 64), 1);
cmd_buffer.end(); cmd_buffer.end();
vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0); vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
@ -494,7 +496,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne); ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
#ifdef false #if 0
const float * x = (float *) ((char *) src0->data); const float * x = (float *) ((char *) src0->data);
const float * y = (float *) ((char *) src1->data); const float * y = (float *) ((char *) src1->data);
float * d_chk = (float *) malloc(sizeof(float) * d_ne); float * d_chk = (float *) malloc(sizeof(float) * d_ne);