Remove redundant checks
This commit is contained in:
parent
3a58a0159b
commit
f54afb4f12
3 changed files with 2 additions and 12 deletions
|
@ -3413,9 +3413,7 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
|
||||||
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type type_a) {
|
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type type_a) {
|
||||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
|
||||||
|
|
||||||
// On F32 matmuls, selecting this way increases performance significantly. On quants or fp16, it reduces performance.
|
if (ctx->device->coopmat2) {
|
||||||
// Maybe because it reduces checks and uses more vector loads, but why is fp16 worse?
|
|
||||||
if (ctx->device->coopmat2 || type_a == GGML_TYPE_F32) {
|
|
||||||
if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n & mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
|
if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n & mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
|
||||||
return aligned ? mmp->a_l : mmp->l;
|
return aligned ? mmp->a_l : mmp->l;
|
||||||
}
|
}
|
||||||
|
@ -3468,7 +3466,7 @@ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ct
|
||||||
if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
|
if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
|
||||||
return aligned ? mmp->a_s : mmp->s;
|
return aligned ? mmp->a_s : mmp->s;
|
||||||
}
|
}
|
||||||
if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64 || ctx->device->coopmat_support)) || !ctx->device->mul_mat_id_l) {
|
if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
|
||||||
return aligned ? mmp->a_m : mmp->m;
|
return aligned ? mmp->a_m : mmp->m;
|
||||||
}
|
}
|
||||||
return aligned ? mmp->a_l : mmp->l;
|
return aligned ? mmp->a_l : mmp->l;
|
||||||
|
|
|
@ -580,12 +580,8 @@ void main() {
|
||||||
|
|
||||||
if (is_aligned && is_in_bounds) {
|
if (is_aligned && is_in_bounds) {
|
||||||
// Full coopMat is within bounds and stride_d is aligned with 16B
|
// Full coopMat is within bounds and stride_d is aligned with 16B
|
||||||
#ifdef ACC_F16
|
|
||||||
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
||||||
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
||||||
#else
|
|
||||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
|
||||||
#endif
|
|
||||||
} else if (is_in_bounds) {
|
} else if (is_in_bounds) {
|
||||||
// Full coopMat is within bounds, but stride_d is not aligned
|
// Full coopMat is within bounds, but stride_d is not aligned
|
||||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||||
|
|
|
@ -292,10 +292,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||||
|
|
||||||
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||||
|
|
||||||
if (f16acc) {
|
|
||||||
base_dict["ACC_F16"] = "1";
|
|
||||||
}
|
|
||||||
|
|
||||||
if (coopmat) {
|
if (coopmat) {
|
||||||
base_dict["COOPMAT"] = "1";
|
base_dict["COOPMAT"] = "1";
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue