Vulkan: Fix float16 use on devices without float16 support + fix subgroup_size_control validation error (#11161)
* Vulkan: Remove float16 use in shaders * Fix validation error about subgroup_size_control extension
This commit is contained in:
parent
ee7136c6d1
commit
c3f9d25706
9 changed files with 50 additions and 51 deletions
|
@ -1,6 +1,6 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "mul_mat_vec_base.comp"
|
||||
|
||||
|
@ -77,10 +77,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
uvec4 q3 = uvec4(unpack8(q3_u32));
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
|
||||
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
|
||||
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
|
||||
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
|
||||
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
|
||||
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
|
||||
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
|
||||
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int l = 0; l < 4; ++l) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue