Make improvements

- Introduce portable sched_getcpu() api
- Support GCC's __target_clones__ feature
- Make fma() go faster on x86 in default mode
- Remove some asan checks from core libraries
- WinMain() now ensures $HOME and $USER are defined
This commit is contained in:
Justine Tunney 2024-02-01 03:39:46 -08:00
parent d5225a693b
commit 2ab9e9f7fd
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
192 changed files with 2809 additions and 932 deletions

View file

@ -16,143 +16,330 @@
TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
*/
#include "libc/calls/calls.h"
#include "libc/calls/struct/timespec.h"
#include "libc/fmt/itoa.h"
#include "libc/inttypes.h"
#include "libc/runtime/runtime.h"
#include "libc/stdio/stdio.h"
#include "libc/str/str.h"
#include "libc/sysv/consts/clock.h"
#include "third_party/double-conversion/double-to-string.h"
#include "third_party/double-conversion/utils.h"
#include "third_party/openmp/omp.h"
#include <algorithm>
#include <atomic>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
// Tolerance for comparisons that are expected to match exactly. It is only
// non-zero when the translation unit is built with fast-math enabled.
#ifndef __FAST_MATH__
#define FLAWLESS 0
#else
#define FLAWLESS 1e-05
#endif

// Tolerance handed to check() when comparing sgemm() against the reference.
#define PRECISION 2e-6

// Budget used by Gemmlin::gemm() to derive its tile edge.
// NOTE(review): presumably the L1 data cache size in bytes (48 KiB) - confirm.
#define LV1DCACHE 49152

// Work-size cutoff used by the OpenMP `if (...)` clauses in this file; loops
// with less work than this stay serial.
#define THRESHOLD 3000000

// Repetition count for optimized, non-ASAN builds; everything else runs once.
// NOTE(review): the consumer is not visible in this view (presumably bench()).
// The value was previously defined twice (10, then 5), which is an ill-formed
// macro redefinition; 5 is the definition that took effect and is kept.
#if defined(__OPTIMIZE__) && !defined(__SANITIZE_ADDRESS__)
#define ITERATIONS 5
#else
#define ITERATIONS 1
#endif
// m×n → (m×n)ᵀ
template <typename T>
void transpose(long m, long n, const T *A, long sa, T *B, long sb) {
#pragma omp parallel for collapse(2)
for (long i = 0; i < m; ++i) {
// Function attribute asking GCC to compile the annotated function with
// -O3 and -ffast-math.
#define OPTIMIZED __attribute__((__optimize__("-O3,-ffast-math")))
// Function attribute listing the x86 targets to build clones for: specific
// microarchitectures first, then the generic "fma" and "avx" feature levels.
// NOTE(review): __target_clones is not defined in this view; presumably it
// wraps GCC's target_clones attribute - confirm against the project headers.
#define PORTABLE \
__target_clones("arch=znver4," \
"arch=znver3," \
"arch=sapphirerapids," \
"arch=alderlake," \
"arch=rocketlake," \
"arch=cooperlake," \
"arch=tigerlake," \
"arch=cascadelake," \
"arch=skylake-avx512," \
"arch=skylake," \
"arch=znver1," \
"arch=tremont," \
"fma," \
"avx")
// Set to true by check_gemm_works() for the duration of its size sweep;
// check() skips its per-comparison error printout while this is true.
static bool is_self_testing;
// m×n → n×m
//
// Writes the transpose of the row-major matrix A (m rows, n columns, row
// stride lda) into the row-major matrix B (n rows, m columns, row stride ldb).
// Element types may differ; each element is converted on assignment.
//
// The loop nest is only parallelised when m * n exceeds THRESHOLD.
//
// A leftover pre-refactor assignment that used the old `sa`/`sb` stride names
// and an unbalanced closing brace have been dropped from the body.
template <typename TA, typename TB>
void transpose(long m, long n, const TA *A, long lda, TB *B, long ldb) {
#pragma omp parallel for collapse(2) if (m * n > THRESHOLD)
  for (long i = 0; i < m; ++i)
    for (long j = 0; j < n; ++j) {
      B[ldb * j + i] = A[lda * i + j];
    }
}
// m×k * k×n → m×n
template <typename T>
void matmul(long m, long n, long k, const T *A, long sa, const T *B, long sb,
T *C, long sc) {
#pragma omp parallel for collapse(2)
for (long i = 0; i < m; ++i) {
// k×m * k×n → m×n if aT
// m×k * n×k → m×n if bT
// k×m * n×k → m×n if aT and bT
//
// Reference matrix multiply: C = alpha * op(A) * op(B) + beta * C, where
// op(X) is X or its transpose depending on aT / bT. All matrices are
// row-major with row strides lda, ldb and ldc. Each dot product is
// accumulated in double with std::fma.
//
// When beta is zero C is written without being read, so callers may pass an
// uninitialised output buffer and stale NaN/Inf values cannot leak into the
// result. Previously `0 * C[i]` was evaluated, which reads indeterminate
// memory and yields NaN whenever the old contents were NaN or Inf.
//
// The loop nest is only parallelised when m * n * k exceeds THRESHOLD.
template <typename TC, typename TA, typename TB>
void dgemm(bool aT, bool bT, long m, long n, long k, float alpha, const TA *A,
           long lda, const TB *B, long ldb, float beta, TC *C, long ldc) {
#pragma omp parallel for collapse(2) if (m * n * k > THRESHOLD)
  for (long i = 0; i < m; ++i)
    for (long j = 0; j < n; ++j) {
      double sum = 0;
      for (long l = 0; l < k; ++l)
        sum = std::fma((aT ? A[lda * l + i] : A[lda * i + l]) * alpha,
                       (bT ? B[ldb * j + l] : B[ldb * l + j]), sum);
      if (beta == 0)
        C[ldc * i + j] = sum;
      else
        C[ldc * i + j] = beta * C[ldc * i + j] + sum;
    }
}
template <long BM, long BN, typename T>
void gemmk(long k, const T *A, long sa, const T *B, long sb, T *C, long sc) {
T S[BM][BN] = {0};
for (long l = 0; l < k; ++l) {
for (long i = 0; i < BM; ++i) {
for (long j = 0; j < BN; ++j) {
S[i][j] += A[sa * l + i] * B[sb * l + j];
// Cache-blocked matrix multiply: C = alpha * op(A) * op(B) + beta * C.
//
// T is the element type of the packed scratch tiles; TA, TB and TC are the
// element types of the caller's row-major matrices (row strides lda, ldb,
// ldc). op(X) is X or its transpose depending on aT / bT.
//
// Changes relative to the previous revision of this class:
//   - the scratch tile is cleared with sizeof(T) rather than sizeof(float),
//     so it is fully zeroed for any T;
//   - when beta is zero C is overwritten instead of multiplied, so an
//     uninitialised or NaN-filled output buffer no longer poisons the result;
//   - leftover pre-refactor lines inside `lock` (stray braces and a loop over
//     undeclared `BM`/`BN`/`S`/`sc`) were removed.
template <typename T, typename TC, typename TA, typename TB>
struct Gemmlin {
 public:
  Gemmlin(bool aT, bool bT, float alpha, const TA *A, long lda, const TB *B,
          long ldb, float beta, TC *C, long ldc)
      : aT(aT),
        bT(bT),
        alpha(alpha),
        A(A),
        lda(lda),
        B(B),
        ldb(ldb),
        beta(beta),
        C(C),
        ldc(ldc) {
  }

  // Computes the m×n product with inner dimension k into C.
  void gemm(long m, long n, long k) {
    if (!m || !n) return;
    // Apply beta up front; the tiled kernels below only ever add into C.
    for (long i = 0; i < m; ++i)
      for (long j = 0; j < n; ++j) {
        if (beta == 0)
          C[ldc * i + j] = 0;  // do not read C: it may be uninitialised
        else
          C[ldc * i + j] *= beta;
      }
    if (!k) return;
    // Tile edge such that three cub×cub tiles of T fit in LV1DCACHE bytes.
    cub = sqrt(LV1DCACHE) / sqrt(sizeof(T) * 3);
    mnpack(0, m, 0, n, 0, k);
  }

 private:
  // Chooses tile sizes for the region [m0,m) × [n0,n) × [k0,k), processes the
  // part that divides into whole tiles, then recurses on the leftover row
  // strip, column strip and corner.
  void mnpack(long m0, long m,  //
              long n0, long n,  //
              long k0, long k) {
    long mc = rounddown(std::min(m - m0, cub), 4);
    long mp = m0 + (m - m0) / mc * mc;
    long nc = rounddown(std::min(n - n0, cub), 4);
    long np = n0 + (n - n0) / nc * nc;
    long kc = rounddown(std::min(k - k0, cub), 4);
    long kp = k0 + (k - k0) / kc * kc;
    kpack(m0, mc, mp, n0, nc, np, k0, kc, k, kp);
    if (m - mp) mnpack(mp, m, n0, np, k0, k);
    if (n - np) mnpack(m0, mp, np, n, k0, k);
    if (m - mp && n - np) mnpack(mp, m, np, n, k0, k);
  }

  // Runs the whole-tile part of the k range [k0,kp), then the remainder
  // [kp,k) as a single shorter tile.
  void kpack(long m0, long mc, long m,  //
             long n0, long nc, long n,  //
             long k0, long kc, long k,  //
             long kp) {
    rpack(m0, mc, m, n0, nc, n, k0, kc, kp);
    if (k - kp) rpack(m0, mc, m, n0, nc, n, kp, k - kp, k);
  }

  // Picks the 4×4 register block when both tile edges are multiples of 4,
  // otherwise falls back to 1×1.
  void rpack(long m0, long mc, long m,  //
             long n0, long nc, long n,  //
             long k0, long kc, long k) {
    if (!(mc % 4) && !(nc % 4))
      bgemm<4, 4>(m0, mc, m, n0, nc, n, k0, kc, k);
    else
      bgemm<1, 1>(m0, mc, m, n0, nc, n, k0, kc, k);
  }

  // Allocates one spin lock per (m-tile, n-tile) of C for this region, runs
  // the tiled kernel, then frees the locks.
  template <int mr, int nr>
  void bgemm(long m0, long mc, long m,  //
             long n0, long nc, long n,  //
             long k0, long kc, long k) {
    ops = (m - m0) * (n - n0) * (k - k0);
    ml = (m - m0) / mc;
    nl = (n - n0) / nc;
    locks = new lock[ml * nl];
    there_will_be_blocks<mr, nr>(m0, mc, m, n0, nc, n, k0, kc, k);
    delete[] locks;
  }

  // Iterates over every (m-tile, k-tile) pair, in parallel when the region is
  // large enough. Different k-tiles add into the same C tile, which is why
  // gizmo() takes the per-tile lock before touching C.
  // NOTE(review): `mc` is declared volatile in the original; the reason is
  // not evident from this file, so it is kept as is.
  template <int mr, int nr>
  void there_will_be_blocks(long m0, volatile long mc, long m, long n0, long nc,
                            long n, long k0, long kc, long k) {
#pragma omp parallel for collapse(2) if (ops > THRESHOLD && mc * kc > 16)
    for (long ic = m0; ic < m; ic += mc)
      for (long pc = k0; pc < k; pc += kc)
        gizmo<mr, nr>(m0, mc, ic, n0, nc, k0, kc, pc, n);
  }

  // Handles one (m-tile, k-tile) pair: packs the alpha-scaled A panel once,
  // then for each n-tile packs B, multiplies into a local tile and adds that
  // tile into C under the tile's lock.
  template <int mr, int nr>
  PORTABLE OPTIMIZED void gizmo(long m0, long mc, long ic, long n0, long nc,
                                long k0, long kc, long pc, long n) {
    T Ac[mc / mr][kc][mr];
    for (long i = 0; i < mc; ++i)
      for (long j = 0; j < kc; ++j)
        Ac[i / mr][j][i % mr] = alpha * (aT ? A[lda * (pc + j) + (ic + i)]
                                            : A[lda * (ic + i) + (pc + j)]);
    for (long jc = n0; jc < n; jc += nc) {
      T Bc[nc / nr][nr][kc];
      for (long j = 0; j < nc; ++j)
        for (long i = 0; i < kc; ++i)
          Bc[j / nr][j % nr][i] =
              bT ? B[ldb * (jc + j) + (pc + i)] : B[ldb * (pc + i) + (jc + j)];
      T Cc[nc / nr][mc / mr][nr][mr];
      // Cc holds nc * mc elements of T (rpack guarantees mr | mc and nr | nc).
      memset(Cc, 0, nc * mc * sizeof(T));
      for (long jr = 0; jr < nc / nr; ++jr)
        for (long ir = 0; ir < mc / mr; ++ir)
          for (long pr = 0; pr < kc; ++pr)
            for (long j = 0; j < nr; ++j)
              for (long i = 0; i < mr; ++i)
                Cc[jr][ir][j][i] += Ac[ir][pr][i] * Bc[jr][j][pr];
      const long lk = nl * ((ic - m0) / mc) + ((jc - n0) / nc);
      locks[lk].acquire();
      for (long ir = 0; ir < mc; ir += mr)
        for (long jr = 0; jr < nc; jr += nr)
          for (long i = 0; i < mr; ++i)
            for (long j = 0; j < nr; ++j)
              C[ldc * (ic + ir + i) + (jc + jr + j)] +=
                  Cc[jr / nr][ir / mr][j][i];
      locks[lk].release();
    }
  }

  // Rounds x down to a multiple of r (r must be a power of two); values
  // smaller than r are returned unchanged.
  inline long rounddown(long x, long r) {
    if (x < r)
      return x;
    else
      return x & -r;
  }

  // Minimal test-and-set spin lock guarding one tile of C.
  class lock {
   public:
    lock() = default;
    void acquire() {
      while (lock_.exchange(true, std::memory_order_acquire)) {
      }
    }
    void release() {
      lock_.store(false, std::memory_order_release);
    }

   private:
    std::atomic_bool lock_ = false;
  };

  bool aT;
  bool bT;
  float alpha;
  const TA *A;
  long lda;
  const TB *B;
  long ldb;
  float beta;
  TC *C;
  long ldc;
  long ops = 0;           // work estimate for the region being processed
  long nl = 0;            // number of n-tiles in the current region
  long ml = 0;            // number of m-tiles in the current region
  lock *locks = nullptr;  // ml * nl locks, valid only inside bgemm()
  long cub = 0;           // tile edge chosen by gemm()
};
template <typename TC, typename TA, typename TB>
void sgemm(bool aT, bool bT, long m, long n, long k, float alpha, const TA *A,
long lda, const TB *B, long ldb, float beta, TC *C, long ldc) {
Gemmlin<float, TC, TA, TB> g{aT, bT, alpha, A, lda, B, ldb, beta, C, ldc};
g.gemm(m, n, k);
}
// (m×k)ᵀ * k×n → m×n
template <long BM, long BN, typename T>
void gemm(long m, long n, long k, const T *A, long sa, const T *B, long sb,
T *C, long sc) {
#pragma omp parallel for collapse(2)
for (long i = 0; i < m; i += BM) {
for (long j = 0; j < n; j += BN) {
gemmk<BM, BN>(k, A + i, sa, B + j, sb, C + sc * i + j, sc);
}
// Prints the row-major m×n matrix A to f as a table with column and row
// indices. Every cell is formatted with "%13.7f"; characters that differ from
// the same cell of B (formatted the same way) are wrapped in the ANSI red /
// reset escape sequences. At most `max` rows and `max` columns of cells are
// printed; a "..." marker replaces the rest.
//
// The stream is locked for the duration of the call so the unlocked stdio
// calls in the inner loop are safe.
//
// Cells are formatted with snprintf into 32-byte buffers. The previous
// sprintf into 16-byte buffers overflowed for any |value| >= 1e7. Leftover
// pre-refactor printf lines (including one that used the undeclared `sa`)
// were removed.
template <typename TA, typename TB>
void show(FILE *f, long max, long m, long n, const TA *A, long lda, const TB *B,
          long ldb) {
  flockfile(f);
  fprintf(f, " ");
  for (long j = 0; j < n; ++j) {
    fprintf(f, "%13ld", j);
  }
  fprintf(f, "\n");
  for (long i = 0; i < m; ++i) {
    if (i == max) {
      fprintf(f, "...\n");
      break;
    }
    fprintf(f, "%5ld ", i);
    for (long j = 0; j < n; ++j) {
      if (j == max) {
        fprintf(f, " ...");
        break;
      }
      char ba[32], bb[32];
      snprintf(ba, sizeof(ba), "%13.7f", static_cast<double>(A[lda * i + j]));
      snprintf(bb, sizeof(bb), "%13.7f", static_cast<double>(B[ldb * i + j]));
      // Stops at the shorter of the two strings.
      for (long k = 0; ba[k] && bb[k]; ++k) {
        if (ba[k] != bb[k]) fputs_unlocked("\33[31m", f);
        fputc_unlocked(ba[k], f);
        if (ba[k] != bb[k]) fputs_unlocked("\33[0m", f);
      }
    }
    fprintf(f, "\n");
  }
  funlockfile(f);
}
template <typename T>
double diff(long m, long n, const T *A, long sa, const T *B, long sb) {
// Returns the object representation of `f` as an unsigned integer.
//
// Uses memcpy rather than reading the inactive member of a union: the union
// trick is only a compiler extension in C++, while memcpy is well defined and
// compiles to the same move.
inline unsigned long GetDoubleBits(double f) {
  unsigned long bits = 0;
  std::memcpy(&bits, &f, sizeof(bits));
  return bits;
}
// Bit-level NaN test: true when, with the sign bit cleared, the encoding is
// strictly greater than that of +infinity.
inline bool IsNan(double x) {
  const unsigned long long kAbsMask = ~0ull >> 1;      // everything but sign
  const unsigned long long kInfBits = 0x7ffull << 52;  // +infinity
  const unsigned long long magnitude = GetDoubleBits(x) & kAbsMask;
  return magnitude > kInfBits;
}
// Returns the mean absolute difference between two row-major m×n matrices,
// `Want` (row stride lda) and `Got` (row stride ldb).
//
// A cell where either side is NaN contributes nothing to the sum but still
// counts in the m * n divisor. Such cells are tallied and reported on stdout:
// a NaN in `Want` is counted as a "want" NaN, otherwise a NaN in `Got` is
// counted as a "got" NaN.
//
// Fix: the NaN test on `Want` now indexes with lda. It used ldb, which reads
// the wrong cell (possibly out of bounds) whenever the two strides differ.
// A leftover pre-refactor loop that referenced the old `A`/`sa`/`B`/`sb`
// names and returned early was also removed.
template <typename TA, typename TB>
double diff(long m, long n, const TA *Want, long lda, const TB *Got, long ldb) {
  double s = 0;
  int got_nans = 0;
  int want_nans = 0;
  for (long i = 0; i < m; ++i)
    for (long j = 0; j < n; ++j)
      if (IsNan(Want[lda * i + j]))
        ++want_nans;
      else if (IsNan(Got[ldb * i + j]))
        ++got_nans;
      else
        s += std::fabs(Want[lda * i + j] - Got[ldb * i + j]);
  if (got_nans) printf("WARNING: got %d NaNs!\n", got_nans);
  if (want_nans) printf("WARNING: want array has %d NaNs!\n", want_nans);
  return s / (m * n);
}
template <typename T>
void check(double tol, long m, long n, const T *A, long sa, const T *B, long sb,
const char *file, long line) {
double sad = diff(m, n, A, sa, B, sb);
if (sad > tol) {
printf("%s:%d: sad %g exceeds %g\n\twant ", file, line, sad, tol);
show(m, n, A, sa);
printf("\n\t got ");
show(m, n, B, sb);
printf("\n");
// Writes a mismatch report to f: a "file:line: sad ... exceeds ..." header,
// then matrix A under "want" and matrix B under "got". Each matrix is printed
// by show() with the other one passed as the comparison, limited to `max`
// rows and columns.
template <typename TA, typename TB>
void show_error(FILE *f, long max, long m, long n, const TA *A, long lda,
                const TB *B, long ldb, const char *file, int line, double sad,
                double tol) {
  fprintf(f, "%s:%d: sad %.17g exceeds %g\nwant\n", file, line, sad, tol);
  show(f, max, m, n, A, lda, B, ldb);
  fputs("got\n", f);
  show(f, max, m, n, B, ldb, A, lda);
  fputc('\n', f);
}
template <typename TA, typename TB>
void check(double tol, long m, long n, const TA *A, long lda, const TB *B,
long ldb, const char *file, int line) {
double sad = diff(m, n, A, lda, B, ldb);
if (sad <= tol) {
if (!is_self_testing) {
printf(" %g error\n", sad);
}
} else {
show_error(stderr, 16, m, n, A, lda, B, ldb, file, line, sad, tol);
const char *path = "/tmp/openmp_test.log";
FILE *f = fopen(path, "w");
if (f) {
show_error(f, 10000, m, n, A, lda, B, ldb, file, line, sad, tol);
printf("see also %s\n", path);
}
exit(1);
}
}
// Call-site wrapper for the check() function above: appends the caller's
// __FILE__ and __LINE__ so a failure report names the comparison that failed.
// (The earlier duplicate definition, which differed only in parameter
// spelling, was an ill-formed macro redefinition and has been dropped.)
#define check(tol, m, n, A, lda, B, ldb) \
  check(tol, m, n, A, lda, B, ldb, __FILE__, __LINE__)
long micros(void) {
struct timespec ts;
@ -196,41 +383,91 @@ void fill(T *A, long n) {
}
}
// Checks matmul() against a hand-computed 2×2 product:
// [1 2; 3 4] * [5 6; 7 8] = [19 22; 43 50].
// NOTE(review): bench() and matmul() are not defined in this view; bench() is
// presumably a timing macro that evaluates its argument - confirm.
void check_reference_gemm_is_ok(void) {
constexpr long m = 2;
constexpr long n = 2;
constexpr long k = 2;
float A[m][k] = {{1, 2}, {3, 4}};
float B[k][n] = {{5, 6}, {7, 8}};
// Output starts out as 666 everywhere, presumably as a sentinel so cells the
// multiply fails to overwrite show up in the comparison.
float C[m][n] = {{666, 666}, {666, 666}};
// Expected result.
float G[m][n] = {{19, 22}, {43, 50}};
bench(matmul(m, n, k, (float *)A, k, (float *)B, n, (float *)C, n));
check(FLAWLESS, m, n, (float *)G, n, (float *)C, n);
}
void check_transposed_blocking_gemm_is_ok(void) {
long m = 1024;
long k = 512;
long n = 80;
// Multiplies a filled m×k matrix by a filled k×n matrix with the reference
// dgemm() and then with sgemm() in all four transpose combinations, checking
// each sgemm() result against the reference within PRECISION.
//
// C and GOLD are value-initialised: both kernels are called with beta == 0
// and apply beta to the existing contents of the output, so uninitialised
// storage could inject NaNs into the comparison.
//
// Leftover pre-refactor lines (a second declaration and delete of `At`, the
// unused `D` buffer, and calls to the removed matmul/gemm helpers) were
// dropped.
void test_gemm(long m, long n, long k) {
  float *A = new float[m * k];
  float *At = new float[k * m];
  float *B = new float[k * n];
  float *Bt = new float[n * k];
  float *C = new float[m * n]();
  float *GOLD = new float[m * n]();
  float alpha = 1;
  float beta = 0;
  fill(A, m * k);
  fill(B, k * n);
  dgemm(0, 0, m, n, k, alpha, A, k, B, n, beta, GOLD, n);
  transpose(m, k, A, k, At, m);
  transpose(k, n, B, n, Bt, k);
  sgemm(0, 0, m, n, k, alpha, A, k, B, n, beta, C, n);
  check(PRECISION, m, n, GOLD, n, C, n);
  sgemm(1, 0, m, n, k, alpha, At, m, B, n, beta, C, n);
  check(PRECISION, m, n, GOLD, n, C, n);
  sgemm(0, 1, m, n, k, alpha, A, k, Bt, k, beta, C, n);
  check(PRECISION, m, n, GOLD, n, C, n);
  sgemm(1, 1, m, n, k, alpha, At, m, Bt, k, beta, C, n);
  check(PRECISION, m, n, GOLD, n, C, n);
  delete[] GOLD;
  delete[] C;
  delete[] Bt;
  delete[] B;
  delete[] At;
  delete[] A;
}
// Runs test_gemm() over every (m, n, K) combination drawn from kSizes, with
// per-comparison output suppressed via is_self_testing. A progress line is
// printed for every 13th combination.
//
// Fix: n and K are now taken from the j and k loop indices. Both used `i`,
// so the nested loops retested the same (m, n, K) triple N*N times and only
// N distinct shapes were ever covered. The sweep now covers N^3 distinct
// shapes, which does more arithmetic than before.
void check_gemm_works(void) {
  static long kSizes[] = {1, 2, 3, 4, 5, 6, 7, 17, 31, 33, 63, 128, 129};
  is_self_testing = true;
  long c = 0;
  long N = sizeof(kSizes) / sizeof(kSizes[0]);
  for (long i = 0; i < N; ++i) {
    long m = kSizes[i];
    for (long j = 0; j < N; ++j) {
      long n = kSizes[N - 1 - j];  // walked in reverse, as before
      for (long k = 0; k < N; ++k) {
        long K = kSizes[k];
        if (c++ % 13 == 0) {
          printf("testing %2ld %2ld %2ld\r", m, n, K);
        }
        test_gemm(m, n, K);
      }
    }
  }
  printf("\r");
  is_self_testing = false;
}
long m = 2333 / 3;
long k = 577 / 3;
long n = 713 / 3;
// Benchmarks the reference dgemm() and sgemm() in all four transpose
// combinations on the file-level m × k × n problem, checking each sgemm()
// result against the double-precision reference within PRECISION.
//
// C and GOLD are value-initialised: both kernels are called with beta == 0
// and apply beta to the existing contents of the output, so uninitialised
// storage could inject NaNs into the comparison. The bench(...) argument text
// is left untouched.
void check_sgemm(void) {
  float *A = new float[m * k];
  float *At = new float[k * m];
  float *B = new float[k * n];
  float *Bt = new float[n * k];
  float *C = new float[m * n]();
  double *GOLD = new double[m * n]();
  fill(A, m * k);
  fill(B, k * n);
  transpose(m, k, A, k, At, m);
  transpose(k, n, B, n, Bt, k);
  bench(dgemm(0, 0, m, n, k, 1, A, k, B, n, 0, GOLD, n));
  bench(sgemm(0, 0, m, n, k, 1, A, k, B, n, 0, C, n));
  check(PRECISION, m, n, GOLD, n, C, n);
  bench(sgemm(1, 0, m, n, k, 1, At, m, B, n, 0, C, n));
  check(PRECISION, m, n, GOLD, n, C, n);
  bench(sgemm(0, 1, m, n, k, 1, A, k, Bt, k, 0, C, n));
  check(PRECISION, m, n, GOLD, n, C, n);
  bench(sgemm(1, 1, m, n, k, 1, At, m, Bt, k, 0, C, n));
  check(PRECISION, m, n, GOLD, n, C, n);
  delete[] GOLD;
  delete[] C;
  delete[] Bt;
  delete[] B;
  delete[] At;
  delete[] A;
}
// Runs every check in order. A failing check terminates the process itself,
// so reaching the end of main means everything passed.
int main(int argc, char *argv[]) {
  using Step = void (*)(void);
  const Step kSteps[] = {
      check_reference_gemm_is_ok,
      check_transposed_blocking_gemm_is_ok,
      check_gemm_works,
      check_sgemm,
  };
  for (Step run : kSteps) run();
}