Add further ops, not yet enabled. Improve semaphore code

This commit is contained in:
0cc4m 2023-10-31 09:49:56 +01:00
parent d130fe6d6b
commit 1cb90e57e4
2 changed files with 420 additions and 313 deletions

File diff suppressed because it is too large Load diff

View file

@ -315,7 +315,7 @@ void main() {
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f; sums[i] = 0.0f;
} }
[[unroll]] for (int block = start_k; block < end_k; block += BK) { [[unroll]] for (int block = start_k; block < end_k; block += BK) {
[[unroll]] for (int l = 0; l < BM; l += loadstride) { [[unroll]] for (int l = 0; l < BM; l += loadstride) {
@ -338,11 +338,11 @@ void main() {
#else #else
if (ir * BM + loadc + l < p.M && block + loadr < p.K) { if (ir * BM + loadc + l < p.M && block + loadr < p.K) {
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]); buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
} else { } else {
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f); buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
} }
#endif #endif
} }
[[unroll]] for (int l = 0; l < BN; l += loadstride) { [[unroll]] for (int l = 0; l < BN; l += loadstride) {
#if LOAD_VEC == 8 #if LOAD_VEC == 8
const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr; const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
@ -363,11 +363,11 @@ void main() {
#else #else
if (ic * BN + loadc + l < p.N && block + loadr < p.K) { if (ic * BN + loadc + l < p.N && block + loadr < p.K) {
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]); buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
} else { } else {
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f); buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
} }
#endif #endif
} }
barrier(); barrier();
@ -379,27 +379,27 @@ void main() {
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (int j = 0; j < TM; j++) { [[unroll]] for (int j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
} }
} }
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (int j = 0; j < TN; j++) { [[unroll]] for (int j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
} }
} }
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (int cc = 0; cc < TN; cc++) { [[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) { [[unroll]] for (int cr = 0; cr < TM; cr++) {
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += D_TYPE(cache_a[wsir * TM + cr]) * D_TYPE(cache_b[wsic * TN + cc]); sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += D_TYPE(cache_a[wsir * TM + cr]) * D_TYPE(cache_b[wsic * TN + cc]);
} }
} }
} }
} }
} }
barrier(); barrier();
} }
const int dr = ir * BM + warp_r * WM; const int dr = ir * BM + warp_r * WM;
const int dc = ic * BN + warp_c * WN; const int dc = ic * BN + warp_c * WN;
@ -415,11 +415,11 @@ void main() {
[[unroll]] for (int cr = 0; cr < TM; cr++) { [[unroll]] for (int cr = 0; cr < TM; cr++) {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) { if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]; data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
} }
} }
} }
} }
} }
} }
""" """
@ -442,7 +442,7 @@ void main() {
if (glr >= p.M || glc >= p.N) { if (glr >= p.M || glc >= p.N) {
return; return;
} }
const int idx = glc * p.M + glr; const int idx = glc * p.M + glr;
@ -450,7 +450,7 @@ void main() {
for (int i = 0; i < p.k_num; i++) { for (int i = 0; i < p.k_num; i++) {
result += data[i * p.M * p.N + idx]; result += data[i * p.M * p.N + idx];
} }
data[idx] = result; data[idx] = result;
} }
@ -801,19 +801,19 @@ void main() {
// matrix multiplication // matrix multiplication
tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(y[iybs + iqs + 0]); tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(y[iybs + iqs + 0]);
tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(y[iybs + iqs + y_offset]); tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(y[iybs + iqs + y_offset]);
} }
// sum up partial sums and write back result // sum up partial sums and write back result
barrier(); barrier();
[[unroll]] for (int s = block_size/2; s > 0; s >>= 1) { [[unroll]] for (int s = block_size/2; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = D_TYPE(tmp[0]); dst[row] = D_TYPE(tmp[0]);
} }
} }
""" """
@ -887,12 +887,12 @@ void main() {
[[unroll]] for (int s = 16; s > 0; s >>= 1) { [[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = D_TYPE(tmp[0]); dst[row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q3_K_body = """ mul_mat_vec_q3_K_body = """
@ -957,12 +957,12 @@ void main() {
[[unroll]] for (int s = 16; s > 0; s >>= 1) { [[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = D_TYPE(tmp[0]); dst[row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q4_K_body = """ mul_mat_vec_q4_K_body = """
@ -1076,12 +1076,12 @@ void main() {
[[unroll]] for (int s = 16; s > 0; s >>= 1) { [[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = D_TYPE(tmp[0]); dst[row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q5_K_body = """ mul_mat_vec_q5_K_body = """
@ -1191,12 +1191,12 @@ void main() {
[[unroll]] for (int s = 16; s > 0; s >>= 1) { [[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = D_TYPE(tmp[0]); dst[row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q6_K_body = """ mul_mat_vec_q6_K_body = """
@ -1276,10 +1276,10 @@ void main() {
tmp[tid] += tmp[tid + s]; tmp[tid] += tmp[tid + s];
} }
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = D_TYPE(tmp[0]); dst[row] = D_TYPE(tmp[0]);
} }
} }
""" """
@ -1307,7 +1307,7 @@ void main() {
if (row < p.K && col < p.M) { if (row < p.K && col < p.M) {
data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]); data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
} }
} }
""" """
@ -1339,7 +1339,7 @@ void main() {
if (x >= p.M || y >= p.N) { if (x >= p.M || y >= p.N) {
return; return;
} }
data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(data_y[p.y_offset + x]); data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(data_y[p.y_offset + x]);
} }
@ -1377,7 +1377,7 @@ void main() {
if (x >= p.M || y >= p.N) { if (x >= p.M || y >= p.N) {
return; return;
} }
data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(FLOAT_TYPE(data_x[p.x_offset + y * p.stride_x + x]) + FLOAT_TYPE(data_y[p.y_offset + x])); data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(FLOAT_TYPE(data_x[p.x_offset + y * p.stride_x + x]) + FLOAT_TYPE(data_y[p.y_offset + x]));
} }
@ -1410,12 +1410,65 @@ void main() {
if (x >= p.M || y >= p.N) { if (x >= p.M || y >= p.N) {
return; return;
} }
data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(p.scale); data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(p.scale);
} }
""" """
# GET_ROWS
get_rows_head = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
"""
get_rows_body = """layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
layout (binding = 0) buffer X {A_TYPE x[];};
layout (binding = 1) buffer Y {int y[];};
layout (binding = 2) buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int M;
int N;
int stride_x;
int stride_y;
int stride_d;
int x_offset;
int y_offset;
int d_offset;
float scale;
} p;
void main() {
const int col = int(gl_GlobalInvocationID.x) * 2;
const int row = int(gl_GlobalInvocationID.y);
if (col >= p.M) {
return;
}
const int r = y[row];
// copy x[r*p.M + col] to dst[row*p.M + col]
const int xi = r*p.M + col;
const int di = row*p.M + col;
const int ib = xi/QUANT_K; // block index
const int iqs = (xi%QUANT_K)/QUANT_R; // quant index
const int iybs = di - di%QUANT_K; // y block start index
const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
DEQUANT_FUNC
dst[iybs + iqs + 0] = D_TYPE(v.x);
dst[iybs + iqs + y_offset] = D_TYPE(v.y);
}
"""
GLSLC = "glslc" GLSLC = "glslc"
VK_NUM_TYPES = 16 VK_NUM_TYPES = 16
@ -1622,6 +1675,29 @@ async def main():
tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}, fp16)) tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}, fp16))
tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}, fp16)) tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}, fp16))
# get_rows
for i in range(0, VK_NUM_TYPES):
stream.clear();
stream.extend((get_rows_head, shader_int8_ext, shader_float_type))
if i == GGML_TYPE_F16:
stream.extend((shader_f16_defines, shader_f16_dequant_func_compat if not fp16 else shader_f16_dequant_func, get_rows_body))
elif i == GGML_TYPE_Q4_0:
stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func_compat if not fp16 else shader_q4_0_dequant_func, get_rows_body))
elif i == GGML_TYPE_Q4_1:
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func_compat if not fp16 else shader_q4_1_dequant_func, get_rows_body))
elif i == GGML_TYPE_Q5_0:
stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func_compat if not fp16 else shader_q5_0_dequant_func, get_rows_body))
elif i == GGML_TYPE_Q5_1:
stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func_compat if not fp16 else shader_q5_1_dequant_func, get_rows_body))
elif i == GGML_TYPE_Q8_0:
stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func_compat if not fp16 else shader_q8_0_dequant_func, get_rows_body))
else:
continue
tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t"}, fp16))
tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}, fp16))
# add # add
stream.clear(); stream.clear();