vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash attention (#10206)
This commit is contained in:
		
							parent
							
								
									6fe6247831
								
							
						
					
					
						commit
						c9c6e01dae
					
				
					 6 changed files with 1665 additions and 97 deletions
				
			
		|  | @ -1,7 +1,9 @@ | |||
| find_package (Threads REQUIRED) | ||||
| find_package(Vulkan COMPONENTS glslc REQUIRED) | ||||
| 
 | ||||
| set(TARGET vulkan-shaders-gen) | ||||
| add_executable(${TARGET} vulkan-shaders-gen.cpp) | ||||
| install(TARGETS ${TARGET} RUNTIME) | ||||
| target_compile_features(${TARGET} PRIVATE cxx_std_17) | ||||
| target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) | ||||
| target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) | ||||
|  |  | |||
							
								
								
									
										305
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										305
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,305 @@ | |||
| 
 | ||||
| #include "types.comp" | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { | ||||
|    block_q4_0_packed16 block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|     float16_t ret = (float16_t(qs) - float16_t(8)) * d; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { | ||||
|    block_q4_1 block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const float16_t m = bl.block.m; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|     float16_t ret = float16_t(qs) * d + m; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { | ||||
|    block_q5_0 block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
| 
 | ||||
|     const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; | ||||
|     const uint qh = ((uint_qh >> idx) << 4) & 0x10; | ||||
| 
 | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
| 
 | ||||
|     float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { | ||||
|    block_q5_1 block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const float16_t m = bl.block.m; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
| 
 | ||||
|     const uint uint_qh = bl.block.qh; | ||||
|     const uint qh = ((uint_qh >> idx) << 4) & 0x10; | ||||
| 
 | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
| 
 | ||||
|     float16_t ret = float16_t(qs | qh) * d + m; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { | ||||
|    block_q8_0_packed16 block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
| 
 | ||||
|     // Load 16b and select the byte for this element | ||||
|     int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; | ||||
|     float16_t ret = float16_t(qs) * d; | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { | ||||
|    block_q2_K block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const f16vec2 d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
| 
 | ||||
|     const uint qsi = (iqs / 128) * 32 + (iqs % 32);     // 0..31 | ||||
|     const uint scalesi = iqs / 16;                      // 0..15 | ||||
|     const uint qsshift = ((iqs % 128) / 32) * 2;        // 0,2,4,6 | ||||
| 
 | ||||
|     uint32_t qs = bl.block.qs[qsi]; | ||||
|     const uint scales = bl.block.scales[scalesi]; | ||||
|     float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { | ||||
|    block_q3_K block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
| 
 | ||||
|     const uint n = iqs / 128;                    // 0,1 | ||||
|     const uint qsi = n * 32 + (iqs % 32);        // 0..63 | ||||
|     const uint hmi =          (iqs % 32);        // 0..31 | ||||
|     const uint j = (iqs % 128) / 8;              // 0..15 | ||||
|     const uint is = iqs / 16;                    // 0..15 | ||||
|     const uint halfsplit = ((iqs % 128) / 32);   // 0,1,2,3 | ||||
|     const uint qsshift = halfsplit * 2;          // 0,2,4,6 | ||||
|     const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128 | ||||
| 
 | ||||
|     uint32_t scaleidx0 = (is < 8) ? is : (is-8); | ||||
|     uint32_t scaleidx0shift = (is < 8) ? 0 : 4; | ||||
|     uint32_t scaleidx1 = is + 8 - (is/4)*4; | ||||
|     uint32_t scaleidx1shift = (is/4)*2; | ||||
| 
 | ||||
|     const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); | ||||
| 
 | ||||
|     const float16_t dl = bl.block.d * float16_t(us - 32); | ||||
| 
 | ||||
|     float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi    ] >> qsshift) & 3) - (((bl.block.hmask[hmi    ] & m) != 0) ? 0 : 4)); | ||||
| 
 | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { | ||||
|    block_q4_K block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
| 
 | ||||
|     const uint n = iqs / 64;                   // 0,1,2,3 | ||||
|     const uint b = (iqs % 64) / 32;            // 0,1 | ||||
|     const uint is = (idx & 0xE0) >> 5;         // 0..7 | ||||
|     const uint qsi = n * 32 + (iqs % 32);      // 0..127 | ||||
| 
 | ||||
|     const f16vec2 loadd = bl.block.d; | ||||
| 
 | ||||
|     uint32_t sc; | ||||
|     uint32_t mbyte; | ||||
| 
 | ||||
|     uint32_t scidx0 = (is < 4) ? is : (is + 4); | ||||
|     uint32_t scidx1 = (is < 4) ? is : (is - 4); | ||||
|     uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t scidxshift1 = (is < 4) ? 0 : 2; | ||||
|     uint32_t mbidx0 = is + 4; | ||||
|     uint32_t mbidx1 = (is < 4) ? is + 4 : is; | ||||
|     uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; | ||||
|     uint32_t mbidxshift0 = (is < 4) ? 0 : 4; | ||||
|     uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t mbidxshift1 = (is < 4) ? 0 : 2; | ||||
| 
 | ||||
|     sc    = uint8_t((bl.block.scales[scidx0] & 0xF)                         | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); | ||||
|     mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); | ||||
| 
 | ||||
|     const float16_t d = loadd.x * float16_t(sc); | ||||
|     const float16_t m = loadd.y * float16_t(mbyte); | ||||
| 
 | ||||
|     uint32_t dmask = 0xF << (b * 4); | ||||
| 
 | ||||
|     float16_t ret = d * float16_t((bl.block.qs[qsi    ] & dmask) >> (b * 4)) - m; | ||||
| 
 | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { | ||||
|    block_q5_K block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
| 
 | ||||
|     const uint n = iqs / 64;                   // 0,1,2,3 | ||||
|     const uint b = (iqs % 64) / 32;            // 0,1 | ||||
|     const uint is = (idx & 0xE0) >> 5;         // 0..7 | ||||
|     const uint qsi = n * 32 + (iqs % 32);      // 0..127 | ||||
|     const uint qhi = (iqs % 32);               // 0..31 | ||||
| 
 | ||||
|     const uint8_t hm = uint8_t(1 << (iqs / 32)); | ||||
| 
 | ||||
|     const f16vec2 loadd = bl.block.d; | ||||
| 
 | ||||
|     uint32_t sc; | ||||
|     uint32_t mbyte; | ||||
| 
 | ||||
|     uint32_t scidx0 = (is < 4) ? is : (is + 4); | ||||
|     uint32_t scidx1 = (is < 4) ? is : (is - 4); | ||||
|     uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t scidxshift1 = (is < 4) ? 0 : 2; | ||||
|     uint32_t mbidx0 = is + 4; | ||||
|     uint32_t mbidx1 = (is < 4) ? is + 4 : is; | ||||
|     uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; | ||||
|     uint32_t mbidxshift0 = (is < 4) ? 0 : 4; | ||||
|     uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t mbidxshift1 = (is < 4) ? 0 : 2; | ||||
| 
 | ||||
|     sc    = uint8_t((bl.block.scales[scidx0] & 0xF)                         | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); | ||||
|     mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); | ||||
| 
 | ||||
|     const float16_t d = loadd.x * float16_t(sc); | ||||
|     const float16_t m = loadd.y * float16_t(mbyte); | ||||
| 
 | ||||
|     uint32_t dmask = 0xF << (b * 4); | ||||
| 
 | ||||
|     float16_t ret = d * (float16_t((bl.block.qs[qsi    ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi    ] & hm) != 0 ? 16 : 0)) - m; | ||||
| 
 | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { | ||||
|    block_q6_K block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
| 
 | ||||
|     const uint n = iqs / 128;                   // 0,1 | ||||
|     const uint b = (iqs % 128) / 64;            // 0,1 | ||||
|     const uint is_b = (iqs % 32) / 16;          // 0,1 | ||||
|     const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6 | ||||
|     const uint is = 8 * n + qhshift + is_b;     // 0..15 | ||||
|     const uint qsi = n * 64 + (iqs % 64);       // 0..127 | ||||
|     const uint qhi = n * 32 + (iqs % 32);       // 0..63 | ||||
| 
 | ||||
|     const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); | ||||
| 
 | ||||
|     float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi    ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi    ] >> qhshift) & 3) << 4)) - 32); | ||||
| 
 | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| #if defined(DATA_A_IQ4_NL) | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { | ||||
|    block_iq4_nl block; | ||||
| }; | ||||
| 
 | ||||
| float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|     float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; | ||||
|     return ret; | ||||
| } | ||||
| #endif | ||||
| 
 | ||||
| #if defined(DATA_A_Q4_0) | ||||
| #define dequantFuncA dequantFuncQ4_0 | ||||
| #elif defined(DATA_A_Q4_1) | ||||
| #define dequantFuncA dequantFuncQ4_1 | ||||
| #elif defined(DATA_A_Q5_0) | ||||
| #define dequantFuncA dequantFuncQ5_0 | ||||
| #elif defined(DATA_A_Q5_1) | ||||
| #define dequantFuncA dequantFuncQ5_1 | ||||
| #elif defined(DATA_A_Q8_0) | ||||
| #define dequantFuncA dequantFuncQ8_0 | ||||
| #elif defined(DATA_A_Q2_K) | ||||
| #define dequantFuncA dequantFuncQ2_K | ||||
| #elif defined(DATA_A_Q3_K) | ||||
| #define dequantFuncA dequantFuncQ3_K | ||||
| #elif defined(DATA_A_Q4_K) | ||||
| #define dequantFuncA dequantFuncQ4_K | ||||
| #elif defined(DATA_A_Q5_K) | ||||
| #define dequantFuncA dequantFuncQ5_K | ||||
| #elif defined(DATA_A_Q6_K) | ||||
| #define dequantFuncA dequantFuncQ6_K | ||||
| #elif defined(DATA_A_IQ4_NL) | ||||
| #define dequantFuncA dequantFuncIQ4_NL | ||||
| #endif | ||||
							
								
								
									
										289
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										289
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,289 @@ | |||
| #version 450 | ||||
| 
 | ||||
| #extension GL_EXT_control_flow_attributes : enable | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
| 
 | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require | ||||
| 
 | ||||
| #extension GL_KHR_memory_scope_semantics : enable | ||||
| #extension GL_KHR_cooperative_matrix : enable | ||||
| #extension GL_NV_cooperative_matrix2 : enable | ||||
| #extension GL_EXT_buffer_reference : enable | ||||
| #extension GL_KHR_shader_subgroup_ballot : enable | ||||
| #extension GL_KHR_shader_subgroup_vote : enable | ||||
| #extension GL_EXT_null_initializer : enable | ||||
| 
 | ||||
| #include "types.comp" | ||||
| #include "dequant_funcs_cm2.comp" | ||||
| 
 | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
| 
 | ||||
| layout (constant_id = 1) const uint32_t Br = 32; | ||||
| layout (constant_id = 2) const uint32_t Bc = 32; | ||||
| layout (constant_id = 3) const uint32_t D = 32; | ||||
| layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; | ||||
| 
 | ||||
| layout (push_constant) uniform parameter { | ||||
|     uint32_t N; | ||||
|     uint32_t KV; | ||||
| 
 | ||||
|     uint32_t ne1; | ||||
|     uint32_t ne2; | ||||
|     uint32_t ne3; | ||||
| 
 | ||||
|     uint32_t neq2; | ||||
|     uint32_t neq3; | ||||
|     uint32_t nek2; | ||||
|     uint32_t nek3; | ||||
|     uint32_t nev2; | ||||
|     uint32_t nev3; | ||||
|     uint32_t nem1; | ||||
| 
 | ||||
|     uint32_t nb02; | ||||
|     uint32_t nb03; | ||||
|     uint32_t nb12; | ||||
|     uint32_t nb13; | ||||
|     uint32_t nb22; | ||||
|     uint32_t nb23; | ||||
|     uint32_t nb31; | ||||
| 
 | ||||
|     float scale; | ||||
|     float max_bias; | ||||
|     float logit_softcap; | ||||
| 
 | ||||
|     uint32_t mask; | ||||
|     uint32_t n_head_log2; | ||||
|     float m0; | ||||
|     float m1; | ||||
| } p; | ||||
| 
 | ||||
| layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; | ||||
| layout (binding = 1) readonly buffer K {uint8_t data_k[];}; | ||||
| layout (binding = 2) readonly buffer V {uint8_t data_v[];}; | ||||
| layout (binding = 3) readonly buffer M {uint8_t data_m[];}; | ||||
| layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; | ||||
| 
 | ||||
| #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) | ||||
| 
 | ||||
| ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { | ||||
|     return max(x, y); | ||||
| } | ||||
| 
 | ||||
| ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { | ||||
|     return x; | ||||
| } | ||||
| 
 | ||||
| // Replace matrix elements >= numRows or numCols with 'replace' | ||||
| ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { | ||||
|     if (row >= numRows || col >= numCols) { | ||||
|         return replace; | ||||
|     } | ||||
|     return elem; | ||||
| } | ||||
| 
 | ||||
| ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) | ||||
| { | ||||
|     return exp(elem); | ||||
| } | ||||
| 
 | ||||
| ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) | ||||
| { | ||||
|     return max(elem0, elem1); | ||||
| } | ||||
| 
 | ||||
| #if defined(BLOCK_SIZE) | ||||
| #define DECODEFUNC , DEQUANTFUNC | ||||
| #else | ||||
| #define DECODEFUNC | ||||
| #endif | ||||
| 
 | ||||
| void main() { | ||||
| #if defined(DATA_A_IQ4_NL) | ||||
|     init_iq4nl_shmem(); | ||||
| #endif | ||||
| 
 | ||||
|     const uint32_t N = p.N; | ||||
|     const uint32_t KV = p.KV; | ||||
| 
 | ||||
|     const uint32_t Tr = CEIL_DIV(N, Br); | ||||
|     const uint32_t Tc = CEIL_DIV(KV, Bc); | ||||
| 
 | ||||
|     const uint32_t i = gl_WorkGroupID.x; | ||||
| 
 | ||||
|     const uint32_t iq2 = gl_WorkGroupID.y; | ||||
|     const uint32_t iq3 = gl_WorkGroupID.z; | ||||
| 
 | ||||
|     // broadcast factors | ||||
|     const uint32_t rk2 = p.neq2/p.nek2; | ||||
|     const uint32_t rk3 = p.neq3/p.nek3; | ||||
| 
 | ||||
|     const uint32_t rv2 = p.neq2/p.nev2; | ||||
|     const uint32_t rv3 = p.neq3/p.nev3; | ||||
| 
 | ||||
|     // k indices | ||||
|     const uint32_t ik3 = iq3 / rk3; | ||||
|     const uint32_t ik2 = iq2 / rk2; | ||||
| 
 | ||||
|     // v indices | ||||
|     const uint32_t iv3 = iq3 / rv3; | ||||
|     const uint32_t iv2 = iq2 / rv2; | ||||
| 
 | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); | ||||
|     tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); | ||||
| 
 | ||||
|     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); | ||||
| 
 | ||||
| #if defined(BLOCK_SIZE) | ||||
|     tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); | ||||
|     tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); | ||||
| #endif | ||||
| 
 | ||||
|     tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); | ||||
|     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); | ||||
|     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); | ||||
| 
 | ||||
|     coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q; | ||||
|     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16; | ||||
| 
 | ||||
|     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; | ||||
|     coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); | ||||
| 
 | ||||
|     Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q); | ||||
|     Qf16 *= float16_t(p.scale); | ||||
| 
 | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0); | ||||
| 
 | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; | ||||
| 
 | ||||
|     L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); | ||||
|     M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0); | ||||
| 
 | ||||
|     ACC_TYPE slope = ACC_TYPE(1.0); | ||||
| 
 | ||||
|     // ALiBi | ||||
|     if (p.max_bias > 0.0f) { | ||||
|         const uint32_t h = iq2; | ||||
| 
 | ||||
|         const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); | ||||
|         const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); | ||||
| 
 | ||||
|         slope = pow(base, ACC_TYPE(exph)); | ||||
|     } | ||||
| 
 | ||||
|     [[dont_unroll]] | ||||
|     for (uint32_t j = 0; j < Tc; ++j) { | ||||
| 
 | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); | ||||
| 
 | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T; | ||||
| 
 | ||||
|         uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; | ||||
|         coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); | ||||
|         S = coopMatMulAdd(Qf16, K_T, S); | ||||
| 
 | ||||
|         if (p.logit_softcap != 0.0f) { | ||||
|             [[unroll]] | ||||
|             for (int k = 0; k < S.length(); ++k) { | ||||
|                 S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if (p.mask != 0) { | ||||
|             tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); | ||||
| 
 | ||||
|             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; | ||||
| 
 | ||||
|             coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); | ||||
| 
 | ||||
|             S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv); | ||||
|         } | ||||
| 
 | ||||
|         // Clear padding elements to -inf, so they don't contribute to rowmax | ||||
|         if (Clamp != 0 && | ||||
|             ((j + 1) * Bc > KV || | ||||
|              (i + 1) * Br > N)) { | ||||
| 
 | ||||
|             uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br; | ||||
|             uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; | ||||
| 
 | ||||
|             coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); | ||||
|         } | ||||
| 
 | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM; | ||||
| 
 | ||||
|         coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); | ||||
| 
 | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M; | ||||
| 
 | ||||
|         // M = max(rowmax, Mold) | ||||
|         // P = e^(S - M) | ||||
|         // eM = e^(Mold - M) | ||||
|         coopMatPerElementNV(M, rowmax, Max, Mold); | ||||
|         coopMatPerElementNV(P, S - M, Exp); | ||||
|         coopMatPerElementNV(eM, Mold - M, Exp); | ||||
| 
 | ||||
|         // Clear padding elements to 0, so they don't contribute to rowsum | ||||
|         if (Clamp != 0 && | ||||
|             ((j + 1) * Bc > KV || | ||||
|              (i + 1) * Br > N)) { | ||||
| 
 | ||||
|             uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br; | ||||
|             uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; | ||||
| 
 | ||||
|             coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); | ||||
|         } | ||||
| 
 | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); | ||||
| 
 | ||||
|         // compute rowsum by multiplying by matrix of all ones. | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); | ||||
| 
 | ||||
|         rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); | ||||
|         rowsum = coopMatMulAdd(P_A, One, rowsum); | ||||
| 
 | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V; | ||||
|         uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; | ||||
|         coopMatLoadTensorNV(V,  data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); | ||||
| 
 | ||||
|         L = eM*L + rowsum; | ||||
| 
 | ||||
|         // This is the "diagonal" matrix in the paper, but since we do componentwise | ||||
|         // multiply rather than matrix multiply it has the diagonal element smeared | ||||
|         // across the row | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag; | ||||
| 
 | ||||
|         // resize eM by using smear/reduce | ||||
|         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); | ||||
| 
 | ||||
|         O = eMdiag * O; | ||||
| 
 | ||||
|         O = coopMatMulAdd(P_A, V, O); | ||||
|     } | ||||
| 
 | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag; | ||||
| 
 | ||||
|     // resize L by using smear/reduce | ||||
|     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); | ||||
| 
 | ||||
|     [[unroll]] | ||||
|     for (int k = 0; k < Ldiag.length(); ++k) { | ||||
|         Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; | ||||
|     } | ||||
| 
 | ||||
|     O = Ldiag*O; | ||||
| 
 | ||||
|     tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); | ||||
| 
 | ||||
|     // permute dimensions | ||||
|     tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); | ||||
|     uint32_t o_offset = iq3*p.ne2*p.ne1; | ||||
| 
 | ||||
|     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O); | ||||
|     coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); | ||||
| } | ||||
							
								
								
									
										328
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										328
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,328 @@ | |||
| #version 450 | ||||
| 
 | ||||
| #extension GL_EXT_control_flow_attributes : enable | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
| 
 | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require | ||||
| 
 | ||||
| #extension GL_KHR_memory_scope_semantics : enable | ||||
| #extension GL_KHR_cooperative_matrix : enable | ||||
| #extension GL_NV_cooperative_matrix2 : enable | ||||
| #extension GL_EXT_buffer_reference : enable | ||||
| #extension GL_KHR_shader_subgroup_ballot : enable | ||||
| #extension GL_KHR_shader_subgroup_vote : enable | ||||
| 
 | ||||
| #include "types.comp" | ||||
| 
 | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
| 
 | ||||
| layout (constant_id = 1) const uint BM = 64; | ||||
| layout (constant_id = 2) const uint BN = 64; | ||||
| layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant | ||||
| 
 | ||||
| layout (push_constant) uniform parameter | ||||
| { | ||||
|     uint M; | ||||
|     uint N; | ||||
|     uint K; | ||||
|     uint stride_a; | ||||
|     uint stride_b; | ||||
|     uint stride_d; | ||||
| 
 | ||||
|     uint batch_stride_a; | ||||
|     uint batch_stride_b; | ||||
|     uint batch_stride_d; | ||||
| 
 | ||||
| #ifdef MUL_MAT_ID | ||||
|     uint nei0; | ||||
|     uint nei1; | ||||
|     uint nbi1; | ||||
|     uint ne11; | ||||
| #else | ||||
|     uint k_split; | ||||
|     uint ne02; | ||||
|     uint ne12; | ||||
|     uint broadcast2; | ||||
|     uint broadcast3; | ||||
| #endif | ||||
| } p; | ||||
| 
 | ||||
| 
 | ||||
| layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; | ||||
| layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; | ||||
| layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; | ||||
| 
 | ||||
| #if QUANT_K > 1 | ||||
| #define DECODEFUNCA , dequantFuncA | ||||
| #define MAT_A_TYPE float16_t | ||||
| 
 | ||||
| #include "dequant_funcs_cm2.comp" | ||||
| 
 | ||||
| #else | ||||
| #define DECODEFUNCA | ||||
| #define MAT_A_TYPE A_TYPE | ||||
| #endif | ||||
| 
 | ||||
| #define MAT_B_TYPE B_TYPE | ||||
| 
 | ||||
| #ifdef MUL_MAT_ID | ||||
| layout (binding = 3) readonly buffer IDS {int data_ids[];}; | ||||
| 
 | ||||
| shared u16vec4 row_ids[3072]; | ||||
| 
 | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { | ||||
|    B_TYPE b[]; | ||||
| }; | ||||
| 
 | ||||
| uint _ne1; | ||||
| shared uint _ne1_sh; | ||||
| 
 | ||||
| B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint row_i = blockCoords[0]; | ||||
| 
 | ||||
|     if (row_i >= _ne1) { | ||||
|         return B_TYPE(0.0); | ||||
|     } | ||||
| 
 | ||||
|     const u16vec4 row_idx = row_ids[row_i]; | ||||
|     B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; | ||||
| 
 | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) | ||||
| { | ||||
|     uint dr = ir * BM + r; | ||||
|     uint dc = ic * BN + c; | ||||
| 
 | ||||
|     if (dr < p.M && dc < _ne1) { | ||||
|         uint row_i = dc; | ||||
|         const u16vec4 row_idx = row_ids[row_i]; | ||||
|         data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; | ||||
|     } | ||||
|     return elem; | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
| 
 | ||||
| void main() { | ||||
| #if defined(DATA_A_IQ4_NL) | ||||
|     init_iq4nl_shmem(); | ||||
| #endif | ||||
| 
 | ||||
| #ifdef MUL_MAT_ID | ||||
|     const uint expert_idx = gl_GlobalInvocationID.z; | ||||
| #else | ||||
|     const uint batch_idx = gl_GlobalInvocationID.z; | ||||
| 
 | ||||
|     const uint i13 = batch_idx / p.ne12; | ||||
|     const uint i12 = batch_idx % p.ne12; | ||||
| 
 | ||||
|     const uint i03 = i13 / p.broadcast3; | ||||
|     const uint i02 = i12 / p.broadcast2; | ||||
| 
 | ||||
|     const uint batch_idx_a = i03 * p.ne02 + i02; | ||||
| #endif | ||||
| 
 | ||||
|     const uint blocks_m = (p.M + BM - 1) / BM; | ||||
|     const uint ir = gl_WorkGroupID.x % blocks_m; | ||||
|     const uint ik = gl_WorkGroupID.x / blocks_m; | ||||
|     const uint ic = gl_WorkGroupID.y; | ||||
| 
 | ||||
| #ifdef MUL_MAT_ID | ||||
|     // Spread the search across all elements in the first subgroup | ||||
|     if (gl_SubgroupID == 0) { | ||||
|         _ne1 = 0; | ||||
|         uint num_elements = p.nei1 * p.nei0; | ||||
| 
 | ||||
|         for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { | ||||
|             bool in_range = i < num_elements; | ||||
|             uint ii0 = i % p.nei0; | ||||
|             uint ii1 = i / p.nei0; | ||||
|             uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; | ||||
|             uvec4 ballot = subgroupBallot(in_range && id == expert_idx); | ||||
|             uint idx = subgroupBallotExclusiveBitCount(ballot); | ||||
|             if (in_range && id == expert_idx) { | ||||
|                 row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); | ||||
|             } | ||||
|             _ne1 += subgroupBallotBitCount(ballot); | ||||
|         } | ||||
|         _ne1_sh = _ne1; | ||||
|     } | ||||
| 
 | ||||
|     barrier(); | ||||
| 
 | ||||
|     _ne1 = _ne1_sh; | ||||
| 
 | ||||
|     // Workgroup has no work | ||||
|     if (ic * BN >= _ne1) return; | ||||
| #endif | ||||
| 
 | ||||
| #ifdef MUL_MAT_ID | ||||
|     uint start_k = 0; | ||||
|     const uint end_k = p.K; | ||||
| #else | ||||
|     uint start_k = ik * p.k_split; | ||||
|     const uint end_k = min(p.K, (ik + 1) * p.k_split); | ||||
| #endif | ||||
| 
 | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum; | ||||
|     sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0); | ||||
| 
 | ||||
| #ifdef MUL_MAT_ID | ||||
|     uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; | ||||
|     uint pos_b = 0; | ||||
| #else | ||||
|     uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; | ||||
|     uint pos_b = batch_idx * p.batch_stride_b; | ||||
| #endif | ||||
| 
 | ||||
|     uint stride_a = p.stride_a / QUANT_K; | ||||
|     uint stride_b = p.stride_b; | ||||
| 
 | ||||
|     // Hint to the compiler that values are aligned (want 16B alignment). | ||||
|     // Quants are always block-aligned, no alignment needed. | ||||
| #if ALIGNED | ||||
| #if QUANT_K == 1 | ||||
|     stride_a &= ~7; | ||||
| #endif | ||||
|     stride_b &= ~7; | ||||
| #endif | ||||
| 
 | ||||
|     // Create layouts for both clamped and unclamped accesses | ||||
|     tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
| 
 | ||||
| #if QUANT_K > 1 | ||||
|     tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); | ||||
|     tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); | ||||
| #endif | ||||
| 
 | ||||
|     // Use end_k rather than p.K as the dimension because that's what | ||||
|     // we need to bound check against when using split_k | ||||
|     tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); | ||||
|     tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); | ||||
|     tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); | ||||
|     tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); | ||||
|     tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); | ||||
| 
 | ||||
|     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); | ||||
| 
 | ||||
| #if !defined(MUL_MAT_ID) | ||||
|     // Detect a fast path where all loads are entirely in bounds and no clamping is required | ||||
|     if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && | ||||
| #if QUANT_K == 1 | ||||
|         (stride_a % 8) == 0 && | ||||
| #endif | ||||
|         (stride_b % 8) == 0 && (start_k % 8) == 0) { | ||||
|         // Hint to the compiler that values are aligned (want 16B alignment) | ||||
|         start_k &= ~7; | ||||
|         stride_b &= ~7; | ||||
| #if QUANT_K == 1 | ||||
|         stride_a &= ~7; | ||||
| #endif | ||||
| 
 | ||||
|         tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); | ||||
|         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); | ||||
| 
 | ||||
|         uint k_iters = (end_k - start_k + BK - 1) / BK; | ||||
| 
 | ||||
|         for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { | ||||
| 
 | ||||
|             coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; | ||||
|             coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b; | ||||
| 
 | ||||
|             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
| 
 | ||||
|             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
| 
 | ||||
|             sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|         } | ||||
|     } else | ||||
| #endif // !defined(MUL_MAT_ID) | ||||
|     { | ||||
|         tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); | ||||
| 
 | ||||
|         tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); | ||||
| 
 | ||||
|         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); | ||||
| 
 | ||||
|         tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); | ||||
| 
 | ||||
|         [[dont_unroll]] | ||||
|         for (uint block_k = start_k; block_k < end_k; block_k += BK) { | ||||
| 
 | ||||
|             coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; | ||||
|             coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b; | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft; | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft; | ||||
| 
 | ||||
|             // Clamping is expensive, so detect different code paths for each combination | ||||
|             // of A and B needing clamping. | ||||
|             bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; | ||||
| #ifdef MUL_MAT_ID | ||||
|             bool unclampedB = true; | ||||
| #else | ||||
|             bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; | ||||
| #endif | ||||
|             if (unclampedA && unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); | ||||
| #ifdef MUL_MAT_ID | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); | ||||
| #else | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); | ||||
| #endif | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } else if (unclampedA && !unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); | ||||
| 
 | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } else if (!unclampedA && unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); | ||||
| #ifdef MUL_MAT_ID | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); | ||||
| #else | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); | ||||
| #endif | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } else if (!unclampedA && !unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); | ||||
| 
 | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     // Convert from ACC_TYPE to D_TYPE | ||||
|     coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d; | ||||
|     mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum); | ||||
| 
 | ||||
| #ifdef MUL_MAT_ID | ||||
|     // Call callback to store each element, remapping row through shared memory | ||||
|     coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); | ||||
| #else | ||||
|     tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); | ||||
| 
 | ||||
|     uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; | ||||
|     coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); | ||||
| #endif | ||||
| } | ||||
|  | @ -30,6 +30,8 @@ | |||
|     #include <fcntl.h> | ||||
| #endif | ||||
| 
 | ||||
| #include <vulkan/vulkan_core.h> | ||||
| 
 | ||||
| #define ASYNCIO_CONCURRENCY 64 | ||||
| 
 | ||||
| std::mutex lock; | ||||
|  | @ -196,15 +198,17 @@ static uint32_t compile_count = 0; | |||
| static std::mutex compile_count_mutex; | ||||
| static std::condition_variable compile_count_cond; | ||||
| 
 | ||||
| void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) { | ||||
|     std::string name = _name + (fp16 ? "" : "_fp32"); | ||||
| void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { | ||||
|     std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); | ||||
|     std::string out_fname = join_paths(output_dir, name + ".spv"); | ||||
|     std::string in_path = join_paths(input_dir, in_fname); | ||||
| 
 | ||||
|     std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; | ||||
| 
 | ||||
|     #ifdef _WIN32 | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; | ||||
|     #else | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o",  out_fname}; | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", in_path, "-o",  out_fname}; | ||||
|     #endif | ||||
| 
 | ||||
|     #ifdef GGML_VULKAN_SHADER_DEBUG_INFO | ||||
|  | @ -254,7 +258,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s | |||
| } | ||||
| 
 | ||||
| static std::vector<std::future<void>> compiles; | ||||
| void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) { | ||||
| void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { | ||||
|     { | ||||
|         // wait until fewer than N compiles are in progress.
 | ||||
|         // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
 | ||||
|  | @ -265,15 +269,15 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const | |||
|         } | ||||
|         compile_count++; | ||||
|     } | ||||
|     compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16)); | ||||
|     compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc)); | ||||
| } | ||||
| 
 | ||||
| void matmul_shaders(bool fp16, bool matmul_id) { | ||||
|     std::string load_vec = fp16 ? "8" : "4"; | ||||
|     std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4"; | ||||
|     std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4"; | ||||
| void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { | ||||
|     std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; | ||||
|     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; | ||||
|     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; | ||||
| 
 | ||||
|     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}}; | ||||
|     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; | ||||
|     std::string shader_name = "matmul"; | ||||
| 
 | ||||
|     if (matmul_id) { | ||||
|  | @ -285,21 +289,31 @@ void matmul_shaders(bool fp16, bool matmul_id) { | |||
|         base_dict["FLOAT16"] = "1"; | ||||
|     } | ||||
| 
 | ||||
|     // Shaders with f16 B_TYPE
 | ||||
|     string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); | ||||
|     string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); | ||||
|     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; | ||||
| 
 | ||||
|     string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); | ||||
|     string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); | ||||
|     std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; | ||||
| 
 | ||||
|     // Shaders with f16 B_TYPE
 | ||||
|     string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc); | ||||
|     string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
| 
 | ||||
|     string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
|     string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc); | ||||
| 
 | ||||
|     for (const auto& tname : type_names) { | ||||
|         std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||
|         // For unaligned, load one at a time for f32/f16, or two at a time for quants
 | ||||
|         std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2"; | ||||
|         std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; | ||||
|         // For aligned matmul loads
 | ||||
|         std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2"; | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16); | ||||
|         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; | ||||
| 
 | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
| 
 | ||||
|         if (tname != "f16" && tname != "f32") { | ||||
|             string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); | ||||
|             string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | @ -307,11 +321,50 @@ void process_shaders() { | |||
|     std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; | ||||
|     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; | ||||
| 
 | ||||
|     // matmul
 | ||||
|     for (const auto& fp16 : {false, true}) { | ||||
|         matmul_shaders(fp16, false); | ||||
|         matmul_shaders(fp16, true); | ||||
|         for (const auto& matmul_id : {false, true}) { | ||||
|             for (const auto& coopmat2 : {false, true}) { | ||||
|                 for (const auto& f16acc : {false, true}) { | ||||
| #if !defined(VK_NV_cooperative_matrix2) | ||||
|                     if (coopmat2) { | ||||
|                         continue; | ||||
|                     } | ||||
| #endif | ||||
|                     if (coopmat2 && !fp16) { | ||||
|                         continue; | ||||
|                     } | ||||
|                     if (!coopmat2 && f16acc) { | ||||
|                         continue; | ||||
|                     } | ||||
|                     matmul_shaders(fp16, matmul_id, coopmat2, f16acc); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| #if defined(VK_NV_cooperative_matrix2) | ||||
|     // flash attention
 | ||||
|     for (const auto& f16acc : {false, true}) { | ||||
|         std::string acctype = f16acc ? "float16_t" : "float"; | ||||
| 
 | ||||
|         for (const auto& tname : type_names) { | ||||
|             if (tname == "f32") { | ||||
|                 continue; | ||||
|             } | ||||
| 
 | ||||
|             if (tname == "f16") { | ||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||
|                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc); | ||||
|             } else { | ||||
|                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||
|                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
|     for (const auto& tname : type_names) { | ||||
|         // mul mat vec
 | ||||
|         std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue