Fix q6_k dequant shader for AMD
This commit is contained in:
parent
da09a02b81
commit
39bd512dd1
2 changed files with 21 additions and 18 deletions
|
@ -468,7 +468,11 @@ layout (push_constant) uniform parameter
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const int i = int(gl_WorkGroupID.x);
|
for (int wgy = 0; wgy < 256; wgy++) {
|
||||||
|
const int i = int(gl_WorkGroupID.x * 256 + wgy);
|
||||||
|
if (i >= p.M * p.K / QUANT_K) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int tid = int(gl_LocalInvocationID.x);
|
const int tid = int(gl_LocalInvocationID.x);
|
||||||
const int ip = tid / 32;
|
const int ip = tid / 32;
|
||||||
const int il = tid - 32 * ip;
|
const int il = tid - 32 * ip;
|
||||||
|
@ -485,6 +489,7 @@ void main() {
|
||||||
y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
|
y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
|
||||||
y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
|
y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
|
||||||
y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
|
y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
|
|
@ -797,20 +797,16 @@ static void ggml_vk_generate_shaders() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int work_group_denom;
|
|
||||||
|
|
||||||
switch ((ggml_type)i) {
|
switch ((ggml_type)i) {
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
stream << dequant_q6_K_body;
|
stream << dequant_q6_K_body;
|
||||||
work_group_denom = 64 * 4;
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
stream << dequant_body;
|
stream << dequant_body;
|
||||||
work_group_denom = 256 * 32;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {work_group_denom, 1, 1}, {}, 1);
|
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// mul mat vec
|
// mul mat vec
|
||||||
|
@ -891,6 +887,8 @@ void ggml_vk_init(void) {
|
||||||
};
|
};
|
||||||
validation_features.setPNext(nullptr);
|
validation_features.setPNext(nullptr);
|
||||||
instance_create_info.setPNext(&validation_features);
|
instance_create_info.setPNext(&validation_features);
|
||||||
|
|
||||||
|
std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
vk_instance = vk::createInstance(instance_create_info);
|
vk_instance = vk::createInstance(instance_create_info);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue