Implement bf16 compiler runtime library

This commit is contained in:
Justine Tunney 2024-08-01 19:42:14 -07:00
parent 9ebacb7892
commit a80ab3f8fe
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
5 changed files with 256 additions and 7 deletions

View file

@ -10,6 +10,8 @@
#include "libc/stdio/stdio.h"
#include "libc/testlib/benchmark.h"
#include "libc/x/xasprintf.h"
#include "third_party/aarch64/arm_neon.internal.h"
#include "third_party/intel/immintrin.internal.h"
#define EXPENSIVE_TESTS 0
@ -18,12 +20,11 @@
#define FASTMATH __attribute__((__optimize__("-O3,-ffast-math")))
#define PORTABLE __target_clones("avx512f,avx")
static unsigned long long lcg = 1;
int rand32(void) {
/* Knuth, D.E., "The Art of Computer Programming," Vol 2,
Seminumerical Algorithms, Third Edition, Addison-Wesley, 1998,
p. 106 (line 26) & p. 108 */
static unsigned long long lcg = 1;
lcg *= 6364136223846793005;
lcg += 1442695040888963407;
return lcg >> 32;
@ -122,6 +123,34 @@ float fdotf_recursive(const float *A, const float *B, size_t n) {
}
}
optimizespeed float fdotf_intrin(const float *A, const float *B, size_t n) {
size_t i = 0;
#ifdef __AVX512F__
__m512 vec[CHUNK] = {};
for (; i + CHUNK * 16 <= n; i += CHUNK * 16)
for (int j = 0; j < CHUNK; ++j)
vec[j] = _mm512_fmadd_ps(_mm512_loadu_ps(A + i + j * 16),
_mm512_loadu_ps(B + i + j * 16), vec[j]);
float res = 0;
for (int j = 0; j < CHUNK; ++j)
res += _mm512_reduce_add_ps(vec[j]);
#elif defined(__aarch64__)
float32x4_t vec[CHUNK] = {};
for (; i + CHUNK * 4 <= n; i += CHUNK * 4)
for (int j = 0; j < CHUNK; ++j)
vec[j] =
vfmaq_f32(vec[j], vld1q_f32(A + i + j * 4), vld1q_f32(B + i + j * 4));
float res = 0;
for (int j = 0; j < CHUNK; ++j)
res += vaddvq_f32(vec[j]);
#else
float res = 0;
#endif
for (; i < n; ++i)
res += A[i] * B[i];
return res;
}
FASTMATH float fdotf_ruler(const float *A, const float *B, size_t n) {
int rule, step = 2;
size_t chunk, sp = 0;
@ -179,6 +208,8 @@ void test_fdotf_ruler(void) {
}
PORTABLE float fdotf_hefty(const float *A, const float *B, size_t n) {
if (1)
return 0;
unsigned i, par, len = 0;
float sum, res[n / CHUNK + 1];
for (res[0] = i = 0; i + CHUNK <= n; i += CHUNK)
@ -244,7 +275,7 @@ int main() {
#if EXPENSIVE_TESTS
size_t n = 512 * 1024;
#else
size_t n = 1024;
size_t n = 4096;
#endif
float *A = new float[n];
@ -253,22 +284,24 @@ int main() {
A[i] = numba();
B[i] = numba();
}
float kahan, naive, dubble, recursive, hefty, ruler;
float kahan, naive, dubble, recursive, ruler, intrin;
test_fdotf_naive();
test_fdotf_hefty();
// test_fdotf_hefty();
test_fdotf_ruler();
BENCHMARK(20, 1, (kahan = barrier(fdotf_kahan(A, B, n))));
BENCHMARK(20, 1, (dubble = barrier(fdotf_dubble(A, B, n))));
BENCHMARK(20, 1, (naive = barrier(fdotf_naive(A, B, n))));
BENCHMARK(20, 1, (recursive = barrier(fdotf_recursive(A, B, n))));
BENCHMARK(20, 1, (intrin = barrier(fdotf_intrin(A, B, n))));
BENCHMARK(20, 1, (ruler = barrier(fdotf_ruler(A, B, n))));
BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
// BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
printf("dubble = %f (%g)\n", dubble, fabs(dubble - dubble));
printf("kahan = %f (%g)\n", kahan, fabs(kahan - dubble));
printf("naive = %f (%g)\n", naive, fabs(naive - dubble));
printf("recursive = %f (%g)\n", recursive, fabs(recursive - dubble));
printf("intrin = %f (%g)\n", intrin, fabs(intrin - dubble));
printf("ruler = %f (%g)\n", ruler, fabs(ruler - dubble));
printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
// printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
delete[] B;
delete[] A;