mirror of
https://github.com/jart/cosmopolitan.git
synced 2025-05-25 23:02:27 +00:00
Implement bf16 compiler runtime library
This commit is contained in:
parent
9ebacb7892
commit
a80ab3f8fe
5 changed files with 256 additions and 7 deletions
|
@ -10,6 +10,8 @@
|
|||
#include "libc/stdio/stdio.h"
|
||||
#include "libc/testlib/benchmark.h"
|
||||
#include "libc/x/xasprintf.h"
|
||||
#include "third_party/aarch64/arm_neon.internal.h"
|
||||
#include "third_party/intel/immintrin.internal.h"
|
||||
|
||||
#define EXPENSIVE_TESTS 0
|
||||
|
||||
|
@ -18,12 +20,11 @@
|
|||
#define FASTMATH __attribute__((__optimize__("-O3,-ffast-math")))
|
||||
#define PORTABLE __target_clones("avx512f,avx")
|
||||
|
||||
static unsigned long long lcg = 1;
|
||||
|
||||
int rand32(void) {
|
||||
/* Knuth, D.E., "The Art of Computer Programming," Vol 2,
|
||||
Seminumerical Algorithms, Third Edition, Addison-Wesley, 1998,
|
||||
p. 106 (line 26) & p. 108 */
|
||||
static unsigned long long lcg = 1;
|
||||
lcg *= 6364136223846793005;
|
||||
lcg += 1442695040888963407;
|
||||
return lcg >> 32;
|
||||
|
@ -122,6 +123,34 @@ float fdotf_recursive(const float *A, const float *B, size_t n) {
|
|||
}
|
||||
}
|
||||
|
||||
optimizespeed float fdotf_intrin(const float *A, const float *B, size_t n) {
|
||||
size_t i = 0;
|
||||
#ifdef __AVX512F__
|
||||
__m512 vec[CHUNK] = {};
|
||||
for (; i + CHUNK * 16 <= n; i += CHUNK * 16)
|
||||
for (int j = 0; j < CHUNK; ++j)
|
||||
vec[j] = _mm512_fmadd_ps(_mm512_loadu_ps(A + i + j * 16),
|
||||
_mm512_loadu_ps(B + i + j * 16), vec[j]);
|
||||
float res = 0;
|
||||
for (int j = 0; j < CHUNK; ++j)
|
||||
res += _mm512_reduce_add_ps(vec[j]);
|
||||
#elif defined(__aarch64__)
|
||||
float32x4_t vec[CHUNK] = {};
|
||||
for (; i + CHUNK * 4 <= n; i += CHUNK * 4)
|
||||
for (int j = 0; j < CHUNK; ++j)
|
||||
vec[j] =
|
||||
vfmaq_f32(vec[j], vld1q_f32(A + i + j * 4), vld1q_f32(B + i + j * 4));
|
||||
float res = 0;
|
||||
for (int j = 0; j < CHUNK; ++j)
|
||||
res += vaddvq_f32(vec[j]);
|
||||
#else
|
||||
float res = 0;
|
||||
#endif
|
||||
for (; i < n; ++i)
|
||||
res += A[i] * B[i];
|
||||
return res;
|
||||
}
|
||||
|
||||
FASTMATH float fdotf_ruler(const float *A, const float *B, size_t n) {
|
||||
int rule, step = 2;
|
||||
size_t chunk, sp = 0;
|
||||
|
@ -179,6 +208,8 @@ void test_fdotf_ruler(void) {
|
|||
}
|
||||
|
||||
PORTABLE float fdotf_hefty(const float *A, const float *B, size_t n) {
|
||||
if (1)
|
||||
return 0;
|
||||
unsigned i, par, len = 0;
|
||||
float sum, res[n / CHUNK + 1];
|
||||
for (res[0] = i = 0; i + CHUNK <= n; i += CHUNK)
|
||||
|
@ -244,7 +275,7 @@ int main() {
|
|||
#if EXPENSIVE_TESTS
|
||||
size_t n = 512 * 1024;
|
||||
#else
|
||||
size_t n = 1024;
|
||||
size_t n = 4096;
|
||||
#endif
|
||||
|
||||
float *A = new float[n];
|
||||
|
@ -253,22 +284,24 @@ int main() {
|
|||
A[i] = numba();
|
||||
B[i] = numba();
|
||||
}
|
||||
float kahan, naive, dubble, recursive, hefty, ruler;
|
||||
float kahan, naive, dubble, recursive, ruler, intrin;
|
||||
test_fdotf_naive();
|
||||
test_fdotf_hefty();
|
||||
// test_fdotf_hefty();
|
||||
test_fdotf_ruler();
|
||||
BENCHMARK(20, 1, (kahan = barrier(fdotf_kahan(A, B, n))));
|
||||
BENCHMARK(20, 1, (dubble = barrier(fdotf_dubble(A, B, n))));
|
||||
BENCHMARK(20, 1, (naive = barrier(fdotf_naive(A, B, n))));
|
||||
BENCHMARK(20, 1, (recursive = barrier(fdotf_recursive(A, B, n))));
|
||||
BENCHMARK(20, 1, (intrin = barrier(fdotf_intrin(A, B, n))));
|
||||
BENCHMARK(20, 1, (ruler = barrier(fdotf_ruler(A, B, n))));
|
||||
BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
|
||||
// BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
|
||||
printf("dubble = %f (%g)\n", dubble, fabs(dubble - dubble));
|
||||
printf("kahan = %f (%g)\n", kahan, fabs(kahan - dubble));
|
||||
printf("naive = %f (%g)\n", naive, fabs(naive - dubble));
|
||||
printf("recursive = %f (%g)\n", recursive, fabs(recursive - dubble));
|
||||
printf("intrin = %f (%g)\n", intrin, fabs(intrin - dubble));
|
||||
printf("ruler = %f (%g)\n", ruler, fabs(ruler - dubble));
|
||||
printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
|
||||
// printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
|
||||
delete[] B;
|
||||
delete[] A;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue