Fix q6_k dequant shader for AMD

This commit is contained in:
0cc4m 2023-10-03 09:31:54 +02:00
parent da09a02b81
commit 39bd512dd1
2 changed files with 21 additions and 18 deletions

View file

@ -468,23 +468,28 @@ layout (push_constant) uniform parameter
} p;
void main() {
const int i = int(gl_WorkGroupID.x);
const int tid = int(gl_LocalInvocationID.x);
const int ip = tid / 32;
const int il = tid - 32 * ip;
const int is = 8 * ip + il / 16;
for (int wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.M * p.K / QUANT_K) {
return;
}
const int tid = int(gl_LocalInvocationID.x);
const int ip = tid / 32;
const int il = tid - 32 * ip;
const int is = 8 * ip + il / 16;
const int y_idx = i * QUANT_K + 128 * ip + il;
const int y_idx = i * QUANT_K + 128 * ip + il;
const int ql_idx = 64 * ip + il;
const uint8_t qh = x[i].qh[32 * ip + il];
const int ql_idx = 64 * ip + il;
const uint8_t qh = x[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(x[i].d);
const FLOAT_TYPE d = FLOAT_TYPE(x[i].d);
y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
}
)";

View file

@ -797,20 +797,16 @@ static void ggml_vk_generate_shaders() {
continue;
}
int work_group_denom;
switch ((ggml_type)i) {
case GGML_TYPE_Q6_K:
stream << dequant_q6_K_body;
work_group_denom = 64 * 4;
break;
default:
stream << dequant_body;
work_group_denom = 256 * 32;
break;
}
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {work_group_denom, 1, 1}, {}, 1);
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
}
// mul mat vec
@ -891,6 +887,8 @@ void ggml_vk_init(void) {
};
validation_features.setPNext(nullptr);
instance_create_info.setPNext(&validation_features);
std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
#endif
vk_instance = vk::createInstance(instance_create_info);