vulkan: request round-to-even for fp16 in im2col/rope_head (#10767)
Vulkan doesn't mandate a specific rounding mode, but the shader_float_controls feature allows rounding mode to be requested if the implementation supports it.
This commit is contained in:
		
							parent
							
								
									dafae66cc2
								
							
						
					
					
						commit
						b685daf386
					
				
					 4 changed files with 31 additions and 5 deletions
				
			
		|  | @ -1,6 +1,11 @@ | |||
| #version 450 | ||||
| 
 | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
| #extension GL_EXT_spirv_intrinsics: enable | ||||
| 
 | ||||
| #if RTE16 | ||||
| spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits | ||||
| #endif | ||||
| 
 | ||||
| layout (push_constant) uniform parameter | ||||
| { | ||||
|  |  | |||
|  | @ -1,6 +1,11 @@ | |||
| #include "types.comp" | ||||
| 
 | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
| #extension GL_EXT_spirv_intrinsics: enable | ||||
| 
 | ||||
| #if RTE16 | ||||
| spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits | ||||
| #endif | ||||
| 
 | ||||
| layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; | ||||
| 
 | ||||
|  |  | |||
|  | @ -461,9 +461,11 @@ void process_shaders() { | |||
| 
 | ||||
|     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); | ||||
| 
 | ||||
|     string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); | ||||
| 
 | ||||
|     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); | ||||
| 
 | ||||
|  | @ -471,6 +473,7 @@ void process_shaders() { | |||
| 
 | ||||
|     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
|     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); | ||||
|     string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); | ||||
| 
 | ||||
|     string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue