Fix q6_k dequant shader for AMD
This commit is contained in:
parent
da09a02b81
commit
39bd512dd1
2 changed files with 21 additions and 18 deletions
|
@ -468,23 +468,28 @@ layout (push_constant) uniform parameter
|
|||
} p;
|
||||
|
||||
void main() {
|
||||
const int i = int(gl_WorkGroupID.x);
|
||||
const int tid = int(gl_LocalInvocationID.x);
|
||||
const int ip = tid / 32;
|
||||
const int il = tid - 32 * ip;
|
||||
const int is = 8 * ip + il / 16;
|
||||
for (int wgy = 0; wgy < 256; wgy++) {
|
||||
const int i = int(gl_WorkGroupID.x * 256 + wgy);
|
||||
if (i >= p.M * p.K / QUANT_K) {
|
||||
return;
|
||||
}
|
||||
const int tid = int(gl_LocalInvocationID.x);
|
||||
const int ip = tid / 32;
|
||||
const int il = tid - 32 * ip;
|
||||
const int is = 8 * ip + il / 16;
|
||||
|
||||
const int y_idx = i * QUANT_K + 128 * ip + il;
|
||||
const int y_idx = i * QUANT_K + 128 * ip + il;
|
||||
|
||||
const int ql_idx = 64 * ip + il;
|
||||
const uint8_t qh = x[i].qh[32 * ip + il];
|
||||
const int ql_idx = 64 * ip + il;
|
||||
const uint8_t qh = x[i].qh[32 * ip + il];
|
||||
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(x[i].d);
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(x[i].d);
|
||||
|
||||
y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
|
||||
y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
|
||||
y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
|
||||
y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
|
||||
y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
|
||||
y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
|
||||
y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
|
||||
y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
|
|
|
@ -797,20 +797,16 @@ static void ggml_vk_generate_shaders() {
|
|||
continue;
|
||||
}
|
||||
|
||||
int work_group_denom;
|
||||
|
||||
switch ((ggml_type)i) {
|
||||
case GGML_TYPE_Q6_K:
|
||||
stream << dequant_q6_K_body;
|
||||
work_group_denom = 64 * 4;
|
||||
break;
|
||||
default:
|
||||
stream << dequant_body;
|
||||
work_group_denom = 256 * 32;
|
||||
break;
|
||||
}
|
||||
|
||||
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {work_group_denom, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
||||
}
|
||||
|
||||
// mul mat vec
|
||||
|
@ -891,6 +887,8 @@ void ggml_vk_init(void) {
|
|||
};
|
||||
validation_features.setPNext(nullptr);
|
||||
instance_create_info.setPNext(&validation_features);
|
||||
|
||||
std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
|
||||
#endif
|
||||
vk_instance = vk::createInstance(instance_create_info);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue