amx : minor opt
ggml-ci
This commit is contained in:
parent
8bfef91b8b
commit
b14b9bf692
2 changed files with 2 additions and 3 deletions
|
@ -78,7 +78,6 @@ inline void parallel_for_ggml(const ggml_compute_params * params, int n, const f
|
||||||
int tbegin, tend;
|
int tbegin, tend;
|
||||||
balance211(n, params->nth, params->ith, tbegin, tend);
|
balance211(n, params->nth, params->ith, tbegin, tend);
|
||||||
f(tbegin, tend);
|
f(tbegin, tend);
|
||||||
ggml_barrier(params->threadpool); // TODO: might not always be needed
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// quantized types that have AMX support
|
// quantized types that have AMX support
|
||||||
|
|
|
@ -1349,10 +1349,10 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
|
||||||
constexpr int row = idx / COLS;
|
constexpr int row = idx / COLS;
|
||||||
constexpr int col = idx % COLS;
|
constexpr int col = idx % COLS;
|
||||||
|
|
||||||
if (col == 0) {
|
if constexpr (col == 0) {
|
||||||
va = _mm512_loadu_ps(A + row * K + k);
|
va = _mm512_loadu_ps(A + row * K + k);
|
||||||
}
|
}
|
||||||
if (row == 0) {
|
if constexpr (row == 0) {
|
||||||
vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
|
vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
|
||||||
}
|
}
|
||||||
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
|
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue