Remove bf16 luts

This commit is contained in:
Justine Tunney 2024-04-21 13:07:34 -07:00
parent 180bfcd8d5
commit d6892c486b
No known key found for this signature in database
GPG key ID: 52965314629936D4

27
ggml.c
View file

@ -322,12 +322,6 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
// precomputed f32 table for f16 (256 KB) (ggml-impl.h) // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
float ggml_table_f32_f16[1 << 16]; float ggml_table_f32_f16[1 << 16];
// precomputed gelu table for bf16 (128 KB)
static ggml_bf16_t ggml_table_gelu_bf16[1 << 16];
// precomputed exp table for bf16 (128 KB)
static ggml_bf16_t ggml_table_exp_bf16[1 << 16];
GGML_CALL const char * ggml_status_to_string(enum ggml_status status) { GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
switch (status) { switch (status) {
case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
@ -1622,14 +1616,13 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
__m512 c2 = _mm512_setzero_ps(); __m512 c2 = _mm512_setzero_ps();
for (; i + 64 <= n; i += 64) { for (; i + 64 <= n; i += 64) {
c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)), c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
(__m512bh)_mm512_loadu_ps((const float *)(y + i))); (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)), c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
(__m512bh)_mm512_loadu_ps((const float *)(y + i + 32))); (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
} }
sumf += (ggml_float)_mm512_reduce_add_ps(c1); sumf += (ggml_float)_mm512_reduce_add_ps(c1);
sumf += (ggml_float)_mm512_reduce_add_ps(c2); sumf += (ggml_float)_mm512_reduce_add_ps(c2);
#undef LOAD
#elif defined(__AVX512F__) #elif defined(__AVX512F__)
#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
__m512 c1 = _mm512_setzero_ps(); __m512 c1 = _mm512_setzero_ps();
@ -1975,16 +1968,6 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
} }
} }
inline static void ggml_vec_gelu_bf16(const int n, ggml_bf16_t * y, const ggml_bf16_t * x) {
for (int i = 0; i < n; ++i) {
union {
ggml_bf16_t f;
uint16_t i;
} u = {x[i]};
y[i] = ggml_table_gelu_bf16[u.i];
}
}
#ifdef GGML_GELU_FP16 #ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t; uint16_t t;
@ -2889,18 +2872,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
for (int i = 0; i < (1 << 16); ++i) { for (int i = 0; i < (1 << 16); ++i) {
union { union {
uint16_t i; uint16_t u16;
ggml_fp16_t fp16; ggml_fp16_t fp16;
ggml_bf16_t bf16;
} u = {i}; } u = {i};
float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
f = GGML_BF16_TO_FP32(u.bf16);
ggml_table_gelu_bf16[i] = GGML_FP32_TO_BF16(ggml_gelu_f32(f));
ggml_table_exp_bf16[i] = GGML_FP32_TO_BF16(expf(f));
} }
const uint64_t t_end = ggml_time_us(); UNUSED(t_end); const uint64_t t_end = ggml_time_us(); UNUSED(t_end);