From 105f818d45efd17bf389a6013f036e22492cd7db Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Wed, 3 May 2023 23:26:39 +0800 Subject: [PATCH] integrated new version of rwkv from upstream --- Makefile | 13 +- otherarch/ggml_rwkv.c | 11588 ---------------------------------------- otherarch/ggml_rwkv.h | 645 --- otherarch/otherarch.h | 57 - otherarch/rwkv.cpp | 476 +- otherarch/rwkv.h | 15 +- 6 files changed, 315 insertions(+), 12479 deletions(-) delete mode 100644 otherarch/ggml_rwkv.c delete mode 100644 otherarch/ggml_rwkv.h diff --git a/Makefile b/Makefile index 75211e4ae..7091f6355 100644 --- a/Makefile +++ b/Makefile @@ -235,9 +235,6 @@ ggml_v1.o: otherarch/ggml_v1.c otherarch/ggml_v1.h ggml_v1_noavx2.o: otherarch/ggml_v1.c otherarch/ggml_v1.h $(CC) $(CFLAGS) $(BONUSCFLAGS1) -c $< -o $@ -ggml_rwkv.o: otherarch/ggml_rwkv.c otherarch/ggml_rwkv.h - $(CC) $(CFLAGS) $(BONUSCFLAGS1) $(BONUSCFLAGS2) -c $< -o $@ - llama.o: llama.cpp llama.h llama-util.h $(CXX) $(CXXFLAGS) -c $< -o $@ @@ -263,19 +260,19 @@ main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) @echo '==== Run ./main -h for help. ====' @echo -koboldcpp: ggml.o ggml_rwkv.o ggml_v1.o expose.o common.o gpttype_adapter.o +koboldcpp: ggml.o ggml_v1.o expose.o common.o gpttype_adapter.o $(DEFAULT_BUILD) -koboldcpp_openblas: ggml_openblas.o ggml_rwkv.o ggml_v1.o expose.o common.o gpttype_adapter.o +koboldcpp_openblas: ggml_openblas.o ggml_v1.o expose.o common.o gpttype_adapter.o $(OPENBLAS_BUILD) -koboldcpp_noavx2: ggml_noavx2.o ggml_rwkv.o ggml_v1_noavx2.o expose.o common.o gpttype_adapter.o +koboldcpp_noavx2: ggml_noavx2.o ggml_v1_noavx2.o expose.o common.o gpttype_adapter.o $(NOAVX2_BUILD) -koboldcpp_openblas_noavx2: ggml_openblas_noavx2.o ggml_rwkv.o ggml_v1_noavx2.o expose.o common.o gpttype_adapter.o +koboldcpp_openblas_noavx2: ggml_openblas_noavx2.o ggml_v1_noavx2.o expose.o common.o gpttype_adapter.o $(OPENBLAS_NOAVX2_BUILD) -koboldcpp_clblast: ggml_clblast.o ggml_rwkv.o ggml_v1.o expose.o common.o gpttype_adapter.o ggml-opencl.o +koboldcpp_clblast: ggml_clblast.o ggml_v1.o expose.o common.o gpttype_adapter.o ggml-opencl.o $(CLBLAST_BUILD) quantize_llama: examples/quantize/quantize.cpp ggml.o llama.o diff --git a/otherarch/ggml_rwkv.c b/otherarch/ggml_rwkv.c deleted file mode 100644 index 7fa6a7d2f..000000000 --- a/otherarch/ggml_rwkv.c +++ /dev/null @@ -1,11588 +0,0 @@ -// Defines CLOCK_MONOTONIC and asprintf on Linux -#define _GNU_SOURCE - -#include "ggml_rwkv.h" - -#if defined(_MSC_VER) || defined(__MINGW32__) -#include // using malloc.h with MSC/MINGW -#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) -#include -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef static_assert -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif - -// https://gist.github.com/rygorous/2144712 -// Public domain, by Fabian "ryg" Giesen -inline static float ggml_rwkv_half_to_float_reference(uint16_t value) { - union FP32 { - uint32_t u; - float f; - }; - - const union FP32 magic = { (254UL - 15UL) << 23 }; - const union FP32 was_inf_nan = { (127UL + 16UL) << 23 }; - - union FP32 out; - - // Exponent/mantissa bits - out.u = (value & 0x7FFFU) << 13; - // Exponent adjust - out.f *= magic.f; - - // Make sure Inf/NaN survive - if (out.f >= was_inf_nan.f) { - out.u |= 255UL << 23; - } - - // Sign bit 
- out.u |= (value & 0x8000UL) << 16; - - return out.f; -} - -#if defined _MSC_VER || defined(__MINGW32__) - -#if !defined(__MINGW32__) -#include -#else -// ref: https://github.com/ggerganov/whisper.cpp/issues/168 -#include -#endif - -typedef volatile LONG atomic_int; -typedef atomic_int atomic_bool; - -static void atomic_store(atomic_int* ptr, LONG val) { - InterlockedExchange(ptr, val); -} -static LONG atomic_load(atomic_int* ptr) { - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { - return InterlockedExchangeAdd(ptr, inc); -} -static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { - return atomic_fetch_add(ptr, -(dec)); -} - -typedef HANDLE pthread_t; - -typedef DWORD thread_ret_t; -static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { - HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); - if (handle == NULL) - { - return EAGAIN; - } - - *out = handle; - return 0; -} - -static int pthread_join(pthread_t thread, void* unused) { - return (int) WaitForSingleObject(thread, INFINITE); -} - -static int sched_yield (void) { - Sleep (0); - return 0; -} -#else -#include -#include - -typedef void* thread_ret_t; -#endif - -// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 -#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __FMA__ -#define __FMA__ -#endif -#ifndef __F16C__ -#define __F16C__ -#endif -#ifndef __SSE3__ -#define __SSE3__ -#endif -#endif - -#ifdef __HAIKU__ -#define static_assert(cond, msg) _Static_assert(cond, msg) -#endif - -#define GGML_RWKV_MLOCK_SUPPORT 0 - -#ifdef __has_include - #if __has_include() - #undef GGML_RWKV_MLOCK_SUPPORT - #define GGML_RWKV_MLOCK_SUPPORT 1 - #include - #endif -#endif - - -/*#define GGML_RWKV_PERF*/ -#define GGML_RWKV_DEBUG 0 -#define GGML_RWKV_GELU_FP16 -#define GGML_RWKV_SILU_FP16 - -#define GGML_RWKV_SOFT_MAX_UNROLL 4 -#define GGML_RWKV_VEC_DOT_UNROLL 2 - -#ifdef GGML_RWKV_USE_ACCELERATE -// uncomment to use vDSP for soft max computation -// note: not sure if it is actually faster -//#define GGML_RWKV_SOFT_MAX_ACCELERATE -#endif - -#if UINTPTR_MAX == 0xFFFFFFFF - #define GGML_RWKV_MEM_ALIGN 4 -#else - #define GGML_RWKV_MEM_ALIGN 16 -#endif - -#define UNUSED(x) (void)(x) -#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) - -#define GGML_RWKV_ASSERT(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "GGML_RWKV_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ - } \ - } while (0) - -#ifdef GGML_RWKV_USE_ACCELERATE -#include -#elif GGML_RWKV_USE_OPENBLAS -#include -#endif - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) - -// floating point type used to accumulate sums -typedef double ggml_rwkv_float; - -// 16-bit float -// on Arm, we use __fp16 -// on x86, we use uint16_t -#ifdef __ARM_NEON - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#define GGML_RWKV_COMPUTE_FP16_TO_FP32(x) ((float) (x)) -#define GGML_RWKV_COMPUTE_FP32_TO_FP16(x) (x) - -#define GGML_RWKV_FP16_TO_FP32(x) ((float) (x)) -#define GGML_RWKV_FP32_TO_FP16(x) (x) - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#include -#endif -#endif - -#ifdef __F16C__ - -#ifdef _MSC_VER -#define GGML_RWKV_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) -#define GGML_RWKV_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) -#else -#define GGML_RWKV_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define GGML_RWKV_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) -#endif - -#elif defined(__POWER9_VECTOR__) - -#define GGML_RWKV_COMPUTE_FP16_TO_FP32(x) ggml_rwkv_compute_fp16_to_fp32(x) -#define GGML_RWKV_COMPUTE_FP32_TO_FP16(x) ggml_rwkv_compute_fp32_to_fp16(x) -/* the inline asm below is about 12% faster than the lookup method */ -#define GGML_RWKV_FP16_TO_FP32(x) GGML_RWKV_COMPUTE_FP16_TO_FP32(x) -#define GGML_RWKV_FP32_TO_FP16(x) GGML_RWKV_COMPUTE_FP32_TO_FP16(x) - -static inline float ggml_rwkv_compute_fp16_to_fp32(ggml_rwkv_fp16_t h) { - register float f; - register double d; - __asm__( - "mtfprd %0,%2\n" - "xscvhpdp %0,%0\n" - "frsp %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=f"(f): - /* in */ "r"(h)); - return f; -} - -static inline ggml_rwkv_fp16_t ggml_rwkv_compute_fp32_to_fp16(float f) { - register double d; - register ggml_rwkv_fp16_t r; - __asm__( /* xscvdphp can work on double or single precision */ - "xscvdphp %0,%2\n" - "mffprd %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=r"(r): - /* in */ "f"(f)); - return r; -} - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float ggml_rwkv_compute_fp16_to_fp32(ggml_rwkv_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline ggml_rwkv_fp16_t ggml_rwkv_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); -} - -#define GGML_RWKV_COMPUTE_FP16_TO_FP32(x) ggml_rwkv_compute_fp16_to_fp32(x) -#define GGML_RWKV_COMPUTE_FP32_TO_FP16(x) ggml_rwkv_compute_fp32_to_fp16(x) - -#endif // __F16C__ - -#endif // __ARM_NEON - -// -// global data -// - -// precomputed gelu table for f16 (128 KB) -static ggml_rwkv_fp16_t table_gelu_f16[1 << 16]; - -// precomputed silu table for f16 (128 KB) -static ggml_rwkv_fp16_t table_silu_f16[1 << 16]; - -// precomputed exp table for f16 (128 KB) -static ggml_rwkv_fp16_t table_exp_f16[1 << 16]; - -// precomputed f32 table for f16 (256 KB) -static float table_f32_f16[1 << 16]; - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_rwkv_lookup_fp16_to_fp32, -// so we define GGML_RWKV_FP16_TO_FP32 and GGML_RWKV_FP32_TO_FP16 elsewhere for NEON. -// This is also true for POWER9. -#if !defined(GGML_RWKV_FP16_TO_FP32) || !defined(GGML_RWKV_FP32_TO_FP16) - - - -inline static float ggml_rwkv_lookup_fp16_to_fp32(ggml_rwkv_fp16_t f) { - // For some reason, lookup table does not work on my machine. - // Replaced lookup with working reference code. 
- // TODO This must be properly debugged and fixed - return ggml_rwkv_half_to_float_reference(f); -} - -#define GGML_RWKV_FP16_TO_FP32(x) ggml_rwkv_lookup_fp16_to_fp32(x) -#define GGML_RWKV_FP32_TO_FP16(x) GGML_RWKV_COMPUTE_FP32_TO_FP16(x) - -#endif - -// note: do not use these inside ggml.c -// these are meant to be used via the ggml.h API -float ggml_rwkv_fp16_to_fp32(ggml_rwkv_fp16_t x) { - return (float) GGML_RWKV_FP16_TO_FP32(x); -} - -ggml_rwkv_fp16_t ggml_rwkv_fp32_to_fp16(float x) { - return GGML_RWKV_FP32_TO_FP16(x); -} - -// -// timing -// - -#if defined(_MSC_VER) || defined(__MINGW32__) -static int64_t timer_freq; -void ggml_rwkv_time_init(void) { - LARGE_INTEGER frequency; - QueryPerformanceFrequency(&frequency); - timer_freq = frequency.QuadPart; -} -int64_t ggml_rwkv_time_ms(void) { - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return (t.QuadPart * 1000) / timer_freq; -} -int64_t ggml_rwkv_time_us(void) { - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return (t.QuadPart * 1000000) / timer_freq; -} -#else -void ggml_rwkv_time_init(void) {} -int64_t ggml_rwkv_time_ms(void) { - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; -} - -int64_t ggml_rwkv_time_us(void) { - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; -} -#endif - -int64_t ggml_rwkv_cycles(void) { - return clock(); -} - -int64_t ggml_rwkv_cycles_per_ms(void) { - return CLOCKS_PER_SEC/1000; -} - -#ifdef GGML_RWKV_PERF -#define ggml_rwkv_perf_time_ms() ggml_rwkv_time_ms() -#define ggml_rwkv_perf_time_us() ggml_rwkv_time_us() -#define ggml_rwkv_perf_cycles() ggml_rwkv_cycles() -#define ggml_rwkv_perf_cycles_per_ms() ggml_rwkv_cycles_per_ms() -#else -#define ggml_rwkv_perf_time_ms() 0 -#define ggml_rwkv_perf_time_us() 0 -#define ggml_rwkv_perf_cycles() 0 -#define ggml_rwkv_perf_cycles_per_ms() 0 -#endif - -// -// cache line -// - -#if defined(__cpp_lib_hardware_interference_size) -#define CACHE_LINE_SIZE hardware_destructive_interference_size -#else -#if defined(__POWER9_VECTOR__) -#define CACHE_LINE_SIZE 128 -#else -#define CACHE_LINE_SIZE 64 -#endif -#endif - -static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); - -// -// quantization -// - -#define QK 32 - -// AVX routines provided by GH user Const-me -// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 -#if __AVX2__ || __AVX512F__ -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 
15 ] interval -static inline __m256i bytesFromNibbles( const uint8_t* rsi ) -{ - // Load 16 bytes from memory - __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi ); - - // Expand bytes into uint16_t values - __m256i bytes = _mm256_cvtepu8_epi16( tmp ); - - // Unpack values into individual bytes - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - __m256i high = _mm256_andnot_si256( lowMask, bytes ); - __m256i low = _mm256_and_si256( lowMask, bytes ); - high = _mm256_slli_epi16( high, 4 ); - bytes = _mm256_or_si256( low, high ); - return bytes; -} - -static inline __m128i packNibbles( __m256i bytes ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m256i lowByte = _mm256_set1_epi16( 0xFF ); - __m256i high = _mm256_andnot_si256( lowByte, bytes ); - __m256i low = _mm256_and_si256( lowByte, bytes ); - high = _mm256_srli_epi16( high, 4 ); - bytes = _mm256_or_si256( low, high ); - - // Compress uint16_t lanes into bytes - __m128i r0 = _mm256_castsi256_si128( bytes ); - __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); - return _mm_packus_epi16( r0, r1 ); -} -#endif - -// method 5 -// blocks of QK elements -// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) -typedef struct { - float d; // delta - uint8_t qs[QK / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding"); - -// method 4 -// blocks of QK elements -// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) -typedef struct { - float d; - float m; - uint8_t qs[QK / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); - -// Method 4 with better outlier handling. -typedef struct { - ggml_rwkv_fp16_t d; - ggml_rwkv_fp16_t m; - // We need only 5 bits for the in-block index, so 16 bits is overkill. - // TODO Optimize if possible - uint16_t outlier_index; - ggml_rwkv_fp16_t outlier_value; - // Nibbles / quants. - uint8_t qs[QK / 2]; -} block_q4_1_o; -static_assert(sizeof(block_q4_1_o) == 8 + QK / 2, "wrong q4_1_o block size/padding"); - -// reference implementation for deterministic creation of model files -static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; - - uint8_t pp[QK/2]; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int l = 0; l < QK; l++) { - const float v = x[i*QK + l]; - amax = MAX(amax, fabsf(v)); - } - - const float d = amax / ((1 << 3) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - y[i].d = d; - - for (int l = 0; l < QK; l += 2) { - const float v0 = x[i*QK + l + 0]*id; - const float v1 = x[i*QK + l + 1]*id; - - const uint8_t vi0 = (int8_t)roundf(v0) + 8; - const uint8_t vi1 = (int8_t)roundf(v1) + 8; - - assert(vi0 >= 0 && vi0 < 16); - assert(vi1 >= 0 && vi1 < 16); - - pp[l/2] = vi0 | (vi1 << 4); - } - - memcpy(y[i].qs, pp, sizeof(pp)); - } -} - -static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); - const int nb = k / QK; - - block_q4_0 * restrict y = vy; - -#if defined(__POWER9_VECTOR__) - const vector float v85 = vec_splats(8.5f); - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; - - for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l); - for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]); - - for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]); - //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]); - amaxv[0] = vec_max(amaxv[0], amaxv[2]); - amaxv[4] = vec_max(amaxv[4], amaxv[6]); - //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]); - amaxv[0] = vec_max(amaxv[0], amaxv[4]); - - amax = MAX( - MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 3) - 1); - const float id = d ? 1.0/d : 0.0; - - y[i].d = d; - - const vector float vid = vec_splats(id); - uint8_t * restrict pb = y[i].qs; - for (int l = 0; l < 8; l++) { - const vector float vf = vec_madd(srcv[l], vid, v85); - const vector signed int vi = vec_signed(vf); - - pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4); - pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4); - } - } -#elif __ARM_NEON - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); - for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); - - for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); - for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); - for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); - - // absolute max - const float amax = MAX( - MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), - MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); - - const float d = amax / ((1 << 3) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - y[i].d = d; - - for (int l = 0; l < 8; l++) { - const float32x4_t v = vmulq_n_f32(srcv[l], id); - const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); - const int32x4_t vi = vcvtq_s32_f32(vf); - - y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); - y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); - } - } -#elif defined(__AVX2__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 7.0f; - y[i].d = d; - const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. 
+15 ] - const __m256i off = _mm256_set1_epi8( 8 ); - i0 = _mm256_add_epi8( i0, off ); - - // Compress the vector into 4 bit/value, and store - __m128i res = packNibbles( i0 ); - _mm_storeu_si128( ( __m128i* )y[i].qs, res ); - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l); - for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]); - - for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]); - for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]); - for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]); - - amax = MAX( - MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 3) - 1); - const float id = d ? 1.0/d : 0.0; - - y[i].d = d; - - for (int l = 0; l < 8; l++) { - const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); - const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf); - - y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4); - y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); - } - } -#else - // scalar - quantize_row_q4_0_reference(x, y, k); -#endif -} - -static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); - const int nb = k / QK; - - block_q4_1 * restrict y = vy; - - uint8_t pp[QK/2]; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int l = 0; l < QK; l++) { - const float v = x[i*QK + l]; - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - y[i].d = d; - y[i].m = min; - - for (int l = 0; l < QK; l += 2) { - const float v0 = (x[i*QK + l + 0] - min)*id; - const float v1 = (x[i*QK + l + 1] - min)*id; - - const uint8_t vi0 = roundf(v0); - const uint8_t vi1 = roundf(v1); - - assert(vi0 >= 0 && vi0 < 16); - assert(vi1 >= 0 && vi1 < 16); - - pp[l/2] = vi0 | (vi1 << 4); - } - - memcpy(y[i].qs, pp, sizeof(pp)); - } -} - -static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); - - const int nb = k / QK; - - block_q4_1 * restrict y = vy; - -#if defined(__AVX2__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max for the block - __m256 vmax; - vmax = _mm256_max_ps( v0, v1 ); - vmax = _mm256_max_ps( vmax, v2 ); - vmax = _mm256_max_ps( vmax, v3 ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Compute min for the block - __m256 vmin; - vmin = _mm256_min_ps( v0, v1 ); - vmin = _mm256_min_ps( vmin, v2 ); - vmin = _mm256_min_ps( vmin, v3 ); - - __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) ); - min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) ); - min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) ); - const float minScalar = _mm_cvtss_f32( min4 ); - - // Quantize these floats - const float d = (maxScalar - minScalar) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].m = minScalar; - y[i].d = d; - - // x = (x-min)*id - const __m256 mul = _mm256_set1_ps( id ); - const __m256 off = _mm256_set1_ps( minScalar ); - v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul ); - v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul ); - v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul ); - v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - // Compress the vector into 4 bit/value, and store - __m128i res = packNibbles( i0 ); - _mm_storeu_si128( ( __m128i* )y[i].qs, res ); - } -#elif __ARM_NEON - for (int i = 0; i < nb; i++) { - float32x4_t srcv[8]; - float32x4_t minv[8]; - float32x4_t 
maxv[8]; - - for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); - - for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]); - for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]); - for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]); - - for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]); - for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]); - for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]); - - const float min = vminvq_f32(minv[0]); - const float max = vmaxvq_f32(maxv[0]); - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - y[i].m = min; - - const float32x4_t minv0 = vdupq_n_f32(min); - - for (int l = 0; l < 8; l++) { - const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id); - const int32x4_t vi = vcvtq_s32_f32(v); - - y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); - y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); - } - } -#else - // scalar - quantize_row_q4_1_reference(x, vy, k); -#endif -} - -static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; - - const block_q4_0 * restrict x = vx; - -#if defined(__AVX2__) - for (int i = 0; i < nb; i++) { - // scale factor - const __m256 d_v = _mm256_broadcast_ss(&x[i].d); - - const uint8_t * restrict pp = x[i].qs; - - for (int l = 0; l < QK; l += 32) { - // Load 32x4-bit integers into 32x8-bit integers - __m256i vx8 = bytesFromNibbles(pp+l/2); - - // Subtract 8 from the integers - vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8)); - - // Convert to 16-bit int - const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); - const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); - - // Convert to 32-bit int -> float 32 - const __m256 vf[4] = { - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) - }; - - // Scale and store - for (int j = 0; j < 4; j++) { - const __m256 result = _mm256_mul_ps(vf[j], d_v); - _mm256_storeu_ps(y + i * QK + l + j*8, result); - } - } - } -#elif defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - const float32x4_t vd = vdupq_n_f32(x[i].d); - - const uint8_t * restrict pp = x[i].qs; - - for (int l = 0; l < QK; l += 16) { - // Load 16x4-bit integers into 8x8-bit integers - const uint8x8_t v8 = vld1_u8(pp + l/2); - - // Expand 4-bit qs to 8-bit bytes - const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); - const uint8x8_t v1 = vshr_n_u8(v8, 4); - - // Convert to signed 8-bit integers - const int8x8_t vs_0 = vreinterpret_s8_u8(v0); - const int8x8_t vs_1 = vreinterpret_s8_u8(v1); - - // Subtract 8 from each byte - const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8)); - const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8)); - - // Interleave and combine - const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1); - const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1); - - const int8x16_t vq = vcombine_s8(vx_0, vx_1); - - // convert to 2x int16x8_t - const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq)); - const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq)); - - // convert to 4x float32x4_t - const 
float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0))); - const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0))); - const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1))); - const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1))); - - // Multiply by d - const float32x4_t r0 = vmulq_f32(vf_0, vd); - const float32x4_t r1 = vmulq_f32(vf_1, vd); - const float32x4_t r2 = vmulq_f32(vf_2, vd); - const float32x4_t r3 = vmulq_f32(vf_3, vd); - - // Store - vst1q_f32(y + i*QK + l + 0, r0); - vst1q_f32(y + i*QK + l + 4, r1); - vst1q_f32(y + i*QK + l + 8, r2); - vst1q_f32(y + i*QK + l + 12, r3); - } - } -#else - // scalar - for (int i = 0; i < nb; i++) { - const float d = x[i].d; - - const uint8_t * restrict pp = x[i].qs; - - for (int l = 0; l < QK; l += 2) { - const uint8_t vi = pp[l/2]; - - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; - - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; - - //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1); - - y[i*QK + l + 0] = v0; - y[i*QK + l + 1] = v1; - - assert(!isnan(y[i*QK + l + 0])); - assert(!isnan(y[i*QK + l + 1])); - } - } -#endif -} - -static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; - - const block_q4_1 * restrict x = vx; - -#if defined(__AVX2__) - for (int i = 0; i < nb; i++) { - const __m256 d_v = _mm256_broadcast_ss(&x[i].d); - const __m256 d_m = _mm256_broadcast_ss(&x[i].m); - - const uint8_t * restrict pp = x[i].qs; - - for (int l = 0; l < QK; l += 32) { - // Load 32x4-bit integers into 32x8-bit integers - __m256i vx8 = bytesFromNibbles(pp+l/2); - - // Convert to 16-bit int - const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); - const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); - - // Convert to 32-bit int -> float 32 - const __m256 vf[4] = { - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) - }; - - // Scale, add m and store - for (int j = 0; j < 4; j++) { - const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); - _mm256_storeu_ps(y + i * QK + l + j*8, result); - } - } - } -// #elif defined(__ARM_NEON) -// for (int i = 0; i < nb; i++) { -// const float32x4_t vd = vdupq_n_f32(x[i].d); -// const float32x4_t vm = vdupq_n_f32(x[i].m); - -// const uint8_t * restrict pp = x[i].qs; - -// for (int l = 0; l < QK; l += 16) { -// // Load 16x4-bit integers into 8x8-bit integers -// const uint8x8_t v8 = vld1_u8(pp + l/2); - -// // Expand 4-bit qs to 8-bit bytes -// const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); -// const uint8x8_t v1 = vshr_n_u8(v8, 4); - -// // Interleave and combine -// const uint8x8_t vx_0 = vzip1_u8(v0, v1); -// const uint8x8_t vx_1 = vzip2_u8(v0, v1); - -// const uint8x16_t vq = vcombine_u8(vx_0, vx_1); - -// // convert to 2x uint16x8_t -// const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq)); -// const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq)); - -// // convert to 4x float32x4_t -// const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0))); -// const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0))); -// const float32x4_t vf_2 = 
vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1))); -// const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1))); - -// // multiply by d and add m -// const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd); -// const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd); -// const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd); -// const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd); - -// // Store -// vst1q_f32(y + i*QK + l + 0, r0); -// vst1q_f32(y + i*QK + l + 4, r1); -// vst1q_f32(y + i*QK + l + 8, r2); -// vst1q_f32(y + i*QK + l + 12, r3); -// } -// } -#else - for (int i = 0; i < nb; i++) { - const float d = x[i].d; - const float m = x[i].m; - - const uint8_t * restrict pp = x[i].qs; - - for (int l = 0; l < QK; l += 2) { - const uint8_t vi = pp[l/2]; - - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; - - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; - - y[i*QK + l + 0] = v0; - y[i*QK + l + 1] = v1; - - assert(!isnan(y[i*QK + l + 0])); - assert(!isnan(y[i*QK + l + 1])); - } - } -#endif -} - -// Q4_1_O - -static inline void quantize_row_q4_1_o_reference_single_block(const float * restrict x, block_q4_1_o * restrict block) { - // An outlier is just the absmax element in the block. - // We store it separately and do not quantize it. - int outlier_index = -1; - float outlier_value = 0.0F; - - for (int l = 0; l < QK; l++) { - const float v = x[l]; - - if (fabsf(v) > fabsf(outlier_value)) { - outlier_index = l; - outlier_value = v; - } - } - - block->outlier_index = outlier_index; - block->outlier_value = GGML_RWKV_COMPUTE_FP32_TO_FP16(outlier_value); - - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int l = 0; l < QK; l++) { - if (l == outlier_index) { - // Ignore outlier when computing range. - continue; - } - - const float v = x[l]; - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0F / d : 0.0F; - - block->d = GGML_RWKV_COMPUTE_FP32_TO_FP16(d); - block->m = GGML_RWKV_COMPUTE_FP32_TO_FP16(min); - - uint8_t pp[QK / 2]; - - for (int l = 0; l < QK; l += 2) { - float v0 = (x[l + 0] - min) * id; - float v1 = (x[l + 1] - min) * id; - - // Write some garbage but valid index for the outlier. 
- if (l + 0 == outlier_index) v0 = 0.0; - if (l + 1 == outlier_index) v1 = 0.0; - - const uint8_t vi0 = roundf(v0); - const uint8_t vi1 = roundf(v1); - - assert(vi0 >= 0 && vi0 < 16); - assert(vi1 >= 0 && vi1 < 16); - - pp[l/2] = vi0 | (vi1 << 4); - } - - memcpy(block->qs, pp, sizeof(pp)); -} - -static inline void dequantize_row_q4_1_o_reference_single_block(block_q4_1_o * restrict block, float * restrict y) { - const float d = ggml_rwkv_half_to_float_reference(block->d); - const float m = ggml_rwkv_half_to_float_reference(block->m); - - const uint8_t * restrict pp = block->qs; - - for (int l = 0; l < QK; l += 2) { - const uint8_t vi = pp[l / 2]; - - const int8_t vi0 = vi & 0xF; - const int8_t vi1 = vi >> 4; - - const float v0 = vi0 * d + m; - const float v1 = vi1 * d + m; - - y[l + 0] = v0; - y[l + 1] = v1; - - assert(!isnan(y[l + 0])); - assert(!isnan(y[l + 1])); - } - - // Restore the outlier - y[block->outlier_index] = ggml_rwkv_half_to_float_reference(block->outlier_value); -} - -static void quantize_row_q4_1_o_reference(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); - const int nb = k / QK; - - block_q4_1_o * restrict y = vy; - - for (int i = 0; i < nb; i++) { - quantize_row_q4_1_o_reference_single_block(x + i * QK, y + i); - } -} - -static void quantize_row_q4_1_o(const float * restrict x, void * restrict vy, int k) { - quantize_row_q4_1_o_reference(x, vy, k); -} - -static void dequantize_row_q4_1_o(const void * restrict vx, float * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; - - const block_q4_1_o * restrict x = vx; - -#if defined(__AVX2__) - for (int i = 0; i < nb; i++) { - const float x_d = ggml_rwkv_half_to_float_reference(x[i].d); - const float x_m = ggml_rwkv_half_to_float_reference(x[i].m); - - const __m256 d_v = _mm256_broadcast_ss(&x_d); - const __m256 d_m = _mm256_broadcast_ss(&x_m); - - const uint8_t * restrict pp = x[i].qs; - - for (int l = 0; l < QK; l += 32) { - // Load 32x4-bit integers into 32x8-bit integers - __m256i vx8 = bytesFromNibbles(pp+l/2); - - // Convert to 16-bit int - const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); - const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); - - // Convert to 32-bit int -> float 32 - const __m256 vf[4] = { - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) - }; - - // Scale, add m and store - for (int j = 0; j < 4; j++) { - const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); - _mm256_storeu_ps(y + i * QK + l + j*8, result); - } - } - - // Restore the outlier - y[i * QK + x[i].outlier_index] = ggml_rwkv_half_to_float_reference(x[i].outlier_value); - } -// #elif defined(__ARM_NEON) -// for (int i = 0; i < nb; i++) { -// const float x_d = ggml_rwkv_half_to_float_reference(x[i].d); -// const float x_m = ggml_rwkv_half_to_float_reference(x[i].m); - -// const float32x4_t vd = vdupq_n_f32(x_d); -// const float32x4_t vm = vdupq_n_f32(x_m); - -// const uint8_t * restrict pp = x[i].qs; - -// for (int l = 0; l < QK; l += 16) { -// // Load 16x4-bit integers into 8x8-bit integers -// const uint8x8_t v8 = vld1_u8(pp + l/2); - -// // Expand 4-bit qs to 8-bit bytes -// const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); -// const uint8x8_t 
v1 = vshr_n_u8(v8, 4); - -// // Interleave and combine -// const uint8x8_t vx_0 = vzip1_u8(v0, v1); -// const uint8x8_t vx_1 = vzip2_u8(v0, v1); - -// const uint8x16_t vq = vcombine_u8(vx_0, vx_1); - -// // convert to 2x uint16x8_t -// const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq)); -// const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq)); - -// // convert to 4x float32x4_t -// const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0))); -// const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0))); -// const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1))); -// const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1))); - -// // multiply by d and add m -// const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd); -// const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd); -// const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd); -// const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd); - -// // Store -// vst1q_f32(y + i*QK + l + 0, r0); -// vst1q_f32(y + i*QK + l + 4, r1); -// vst1q_f32(y + i*QK + l + 8, r2); -// vst1q_f32(y + i*QK + l + 12, r3); -// } - -// // Restore the outlier -// y[i * QK + x[i].outlier_index] = ggml_rwkv_half_to_float_reference(x[i].outlier_value); -// } -#else - for (int i = 0; i < nb; i++) { - dequantize_row_q4_1_o_reference_single_block(x + i, y + i * QK); - } -#endif -} - -// -// simd mappings -// - -// we define a common set of C macros which map to specific intrinsics based on the current architecture -// we then implement the fundamental computation operations below using only these macros -// adding support for new architectures requires to define the corresponding SIMD macros -// -// GGML_RWKV_F32_STEP / GGML_RWKV_F16_STEP -// number of elements to process in a single step -// -// GGML_RWKV_F32_EPR / GGML_RWKV_F16_EPR -// number of elements to fit in a single register -// - -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - -#define GGML_RWKV_SIMD - -// F32 NEON - -#define GGML_RWKV_F32_STEP 16 -#define GGML_RWKV_F32_EPR 4 - -#define GGML_RWKV_F32x4 float32x4_t -#define GGML_RWKV_F32x4_ZERO vdupq_n_f32(0.0f) -#define GGML_RWKV_F32x4_SET1(x) vdupq_n_f32(x) -#define GGML_RWKV_F32x4_LOAD vld1q_f32 -#define GGML_RWKV_F32x4_STORE vst1q_f32 -#define GGML_RWKV_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) -#define GGML_RWKV_F32x4_ADD vaddq_f32 -#define GGML_RWKV_F32x4_MUL vmulq_f32 -#if defined(__ARM_FEATURE_QRDMX) - #define GGML_RWKV_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#else - #define GGML_RWKV_F32x4_REDUCE_ONE(x) \ - (vgetq_lane_f32(x, 0) + \ - vgetq_lane_f32(x, 1) + \ - vgetq_lane_f32(x, 2) + \ - vgetq_lane_f32(x, 3)) -#endif -#define GGML_RWKV_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_RWKV_F32_ARR/2; ++i) { \ - x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/4; ++i) { \ - x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/8; ++i) { \ - x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ - } \ - res = GGML_RWKV_F32x4_REDUCE_ONE(x[0]); \ -} - -#define GGML_RWKV_F32_VEC GGML_RWKV_F32x4 -#define GGML_RWKV_F32_VEC_ZERO GGML_RWKV_F32x4_ZERO -#define GGML_RWKV_F32_VEC_SET1 GGML_RWKV_F32x4_SET1 -#define GGML_RWKV_F32_VEC_LOAD GGML_RWKV_F32x4_LOAD -#define GGML_RWKV_F32_VEC_STORE GGML_RWKV_F32x4_STORE -#define GGML_RWKV_F32_VEC_FMA GGML_RWKV_F32x4_FMA -#define GGML_RWKV_F32_VEC_ADD GGML_RWKV_F32x4_ADD -#define GGML_RWKV_F32_VEC_MUL GGML_RWKV_F32x4_MUL -#define GGML_RWKV_F32_VEC_REDUCE GGML_RWKV_F32x4_REDUCE - -// F16 NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - 
#define GGML_RWKV_F16_STEP 32 - #define GGML_RWKV_F16_EPR 8 - - #define GGML_RWKV_F16x8 float16x8_t - #define GGML_RWKV_F16x8_ZERO vdupq_n_f16(0.0f) - #define GGML_RWKV_F16x8_SET1(x) vdupq_n_f16(x) - #define GGML_RWKV_F16x8_LOAD vld1q_f16 - #define GGML_RWKV_F16x8_STORE vst1q_f16 - #define GGML_RWKV_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) - #define GGML_RWKV_F16x8_ADD vaddq_f16 - #define GGML_RWKV_F16x8_MUL vmulq_f16 - #define GGML_RWKV_F16x8_REDUCE(res, x) \ - { \ - for (int i = 0; i < GGML_RWKV_F16_ARR/2; ++i) { \ - x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_RWKV_F16_ARR/4; ++i) { \ - x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_RWKV_F16_ARR/8; ++i) { \ - x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ - } \ - const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ - const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ - res = (ggml_rwkv_float) vaddvq_f32(vaddq_f32(t0, t1)); \ - } - - #define GGML_RWKV_F16_VEC GGML_RWKV_F16x8 - #define GGML_RWKV_F16_VEC_ZERO GGML_RWKV_F16x8_ZERO - #define GGML_RWKV_F16_VEC_SET1 GGML_RWKV_F16x8_SET1 - #define GGML_RWKV_F16_VEC_LOAD(p, i) GGML_RWKV_F16x8_LOAD(p) - #define GGML_RWKV_F16_VEC_STORE(p, r, i) GGML_RWKV_F16x8_STORE(p, r[i]) - #define GGML_RWKV_F16_VEC_FMA GGML_RWKV_F16x8_FMA - #define GGML_RWKV_F16_VEC_ADD GGML_RWKV_F16x8_ADD - #define GGML_RWKV_F16_VEC_MUL GGML_RWKV_F16x8_MUL - #define GGML_RWKV_F16_VEC_REDUCE GGML_RWKV_F16x8_REDUCE -#else - // if FP16 vector arithmetic is not supported, we use FP32 instead - // and take advantage of the vcvt_ functions to convert to/from FP16 - - #define GGML_RWKV_F16_STEP 16 - #define GGML_RWKV_F16_EPR 4 - - #define GGML_RWKV_F32Cx4 float32x4_t - #define GGML_RWKV_F32Cx4_ZERO vdupq_n_f32(0.0f) - #define GGML_RWKV_F32Cx4_SET1(x) vdupq_n_f32(x) - #define GGML_RWKV_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) - #define GGML_RWKV_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) - #define GGML_RWKV_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) - #define GGML_RWKV_F32Cx4_ADD vaddq_f32 - #define GGML_RWKV_F32Cx4_MUL vmulq_f32 - #define GGML_RWKV_F32Cx4_REDUCE GGML_RWKV_F32x4_REDUCE - - #define GGML_RWKV_F16_VEC GGML_RWKV_F32Cx4 - #define GGML_RWKV_F16_VEC_ZERO GGML_RWKV_F32Cx4_ZERO - #define GGML_RWKV_F16_VEC_SET1 GGML_RWKV_F32Cx4_SET1 - #define GGML_RWKV_F16_VEC_LOAD(p, i) GGML_RWKV_F32Cx4_LOAD(p) - #define GGML_RWKV_F16_VEC_STORE(p, r, i) GGML_RWKV_F32Cx4_STORE(p, r[i]) - #define GGML_RWKV_F16_VEC_FMA GGML_RWKV_F32Cx4_FMA - #define GGML_RWKV_F16_VEC_ADD GGML_RWKV_F32Cx4_ADD - #define GGML_RWKV_F16_VEC_MUL GGML_RWKV_F32Cx4_MUL - #define GGML_RWKV_F16_VEC_REDUCE GGML_RWKV_F32Cx4_REDUCE -#endif - -#elif defined(__AVX__) - -#define GGML_RWKV_SIMD - -// F32 AVX - -#define GGML_RWKV_F32_STEP 32 -#define GGML_RWKV_F32_EPR 8 - -#define GGML_RWKV_F32x8 __m256 -#define GGML_RWKV_F32x8_ZERO _mm256_setzero_ps() -#define GGML_RWKV_F32x8_SET1(x) _mm256_set1_ps(x) -#define GGML_RWKV_F32x8_LOAD _mm256_loadu_ps -#define GGML_RWKV_F32x8_STORE _mm256_storeu_ps -#if defined(__FMA__) - #define GGML_RWKV_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) -#else - #define GGML_RWKV_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) -#endif -#define GGML_RWKV_F32x8_ADD _mm256_add_ps -#define GGML_RWKV_F32x8_MUL _mm256_mul_ps -#define GGML_RWKV_F32x8_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_RWKV_F32_ARR/2; ++i) { \ - x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/4; ++i) { \ - x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; 
i < GGML_RWKV_F32_ARR/8; ++i) { \ - x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ - } \ - const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ - _mm256_extractf128_ps(x[0], 1)); \ - const __m128 t1 = _mm_hadd_ps(t0, t0); \ - res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} -// TODO: is this optimal ? - -#define GGML_RWKV_F32_VEC GGML_RWKV_F32x8 -#define GGML_RWKV_F32_VEC_ZERO GGML_RWKV_F32x8_ZERO -#define GGML_RWKV_F32_VEC_SET1 GGML_RWKV_F32x8_SET1 -#define GGML_RWKV_F32_VEC_LOAD GGML_RWKV_F32x8_LOAD -#define GGML_RWKV_F32_VEC_STORE GGML_RWKV_F32x8_STORE -#define GGML_RWKV_F32_VEC_FMA GGML_RWKV_F32x8_FMA -#define GGML_RWKV_F32_VEC_ADD GGML_RWKV_F32x8_ADD -#define GGML_RWKV_F32_VEC_MUL GGML_RWKV_F32x8_MUL -#define GGML_RWKV_F32_VEC_REDUCE GGML_RWKV_F32x8_REDUCE - -// F16 AVX - -#define GGML_RWKV_F16_STEP 32 -#define GGML_RWKV_F16_EPR 8 - -// F16 arithmetic is not supported by AVX, so we use F32 instead - -#define GGML_RWKV_F32Cx8 __m256 -#define GGML_RWKV_F32Cx8_ZERO _mm256_setzero_ps() -#define GGML_RWKV_F32Cx8_SET1(x) _mm256_set1_ps(x) - -#if defined(__F16C__) -// the _mm256_cvt intrinsics require F16C -#define GGML_RWKV_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) -#define GGML_RWKV_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) -#else -static inline __m256 __avx_f32cx8_load(ggml_rwkv_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) - tmp[i] = GGML_RWKV_FP16_TO_FP32(x[i]); - - return _mm256_loadu_ps(tmp); -} -static inline void __avx_f32cx8_store(ggml_rwkv_fp16_t *x, __m256 y) { - float arr[8]; - - _mm256_storeu_ps(arr, y); - - for (int i = 0; i < 8; i++) - x[i] = GGML_RWKV_FP32_TO_FP16(arr[i]); -} -#define GGML_RWKV_F32Cx8_LOAD(x) __avx_f32cx8_load(x) -#define GGML_RWKV_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) -#endif - -#define GGML_RWKV_F32Cx8_FMA GGML_RWKV_F32x8_FMA -#define GGML_RWKV_F32Cx8_ADD _mm256_add_ps -#define GGML_RWKV_F32Cx8_MUL _mm256_mul_ps -#define GGML_RWKV_F32Cx8_REDUCE GGML_RWKV_F32x8_REDUCE - -#define GGML_RWKV_F16_VEC GGML_RWKV_F32Cx8 -#define GGML_RWKV_F16_VEC_ZERO GGML_RWKV_F32Cx8_ZERO -#define GGML_RWKV_F16_VEC_SET1 GGML_RWKV_F32Cx8_SET1 -#define GGML_RWKV_F16_VEC_LOAD(p, i) GGML_RWKV_F32Cx8_LOAD(p) -#define GGML_RWKV_F16_VEC_STORE(p, r, i) GGML_RWKV_F32Cx8_STORE(p, r[i]) -#define GGML_RWKV_F16_VEC_FMA GGML_RWKV_F32Cx8_FMA -#define GGML_RWKV_F16_VEC_ADD GGML_RWKV_F32Cx8_ADD -#define GGML_RWKV_F16_VEC_MUL GGML_RWKV_F32Cx8_MUL -#define GGML_RWKV_F16_VEC_REDUCE GGML_RWKV_F32Cx8_REDUCE - -#elif defined(__POWER9_VECTOR__) - -#define GGML_RWKV_SIMD - -// F32 POWER9 - -#define GGML_RWKV_F32_STEP 32 -#define GGML_RWKV_F32_EPR 4 - -#define GGML_RWKV_F32x4 vector float -#define GGML_RWKV_F32x4_ZERO 0.0f -#define GGML_RWKV_F32x4_SET1 vec_splats -#define GGML_RWKV_F32x4_LOAD(p) vec_xl(0, p) -#define GGML_RWKV_F32x4_STORE(p, r) vec_xst(r, 0, p) -#define GGML_RWKV_F32x4_FMA(a, b, c) vec_madd(b, c, a) -#define GGML_RWKV_F32x4_ADD vec_add -#define GGML_RWKV_F32x4_MUL vec_mul -#define GGML_RWKV_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_RWKV_F32_ARR/2; ++i) { \ - x[2*i] = vec_add(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/4; ++i) { \ - x[4*i] = vec_add(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/8; ++i) { \ - x[8*i] = vec_add(x[8*i], x[8*i+4]); \ - } \ - res = vec_extract(x[0], 0) + \ - vec_extract(x[0], 1) + \ - vec_extract(x[0], 2) + \ - vec_extract(x[0], 3); \ -} - -#define GGML_RWKV_F32_VEC GGML_RWKV_F32x4 -#define GGML_RWKV_F32_VEC_ZERO 
GGML_RWKV_F32x4_ZERO -#define GGML_RWKV_F32_VEC_SET1 GGML_RWKV_F32x4_SET1 -#define GGML_RWKV_F32_VEC_LOAD GGML_RWKV_F32x4_LOAD -#define GGML_RWKV_F32_VEC_STORE GGML_RWKV_F32x4_STORE -#define GGML_RWKV_F32_VEC_FMA GGML_RWKV_F32x4_FMA -#define GGML_RWKV_F32_VEC_ADD GGML_RWKV_F32x4_ADD -#define GGML_RWKV_F32_VEC_MUL GGML_RWKV_F32x4_MUL -#define GGML_RWKV_F32_VEC_REDUCE GGML_RWKV_F32x4_REDUCE - -// F16 POWER9 -#define GGML_RWKV_F16_STEP GGML_RWKV_F32_STEP -#define GGML_RWKV_F16_EPR GGML_RWKV_F32_EPR -#define GGML_RWKV_F16_VEC GGML_RWKV_F32x4 -#define GGML_RWKV_F16_VEC_ZERO GGML_RWKV_F32x4_ZERO -#define GGML_RWKV_F16_VEC_SET1 GGML_RWKV_F32x4_SET1 -#define GGML_RWKV_F16_VEC_FMA GGML_RWKV_F32x4_FMA -#define GGML_RWKV_F16_VEC_REDUCE GGML_RWKV_F32x4_REDUCE -// Use vec_xl, not vec_ld, in case the load address is not aligned. -#define GGML_RWKV_F16_VEC_LOAD(p, i) (i & 0x1) ? \ - vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_RWKV_F16_EPR)) : \ - vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define GGML_RWKV_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] -#define GGML_RWKV_F16_VEC_STORE(p, r, i) \ - if (i & 0x1) \ - vec_xst(vec_pack_to_short_fp32(r[i - GGML_RWKV_ENDIAN_BYTE(1)], \ - r[i - GGML_RWKV_ENDIAN_BYTE(0)]), \ - 0, p - GGML_RWKV_F16_EPR) - -#elif defined(__wasm_simd128__) - -#define GGML_RWKV_SIMD - -// F32 WASM - -#define GGML_RWKV_F32_STEP 16 -#define GGML_RWKV_F32_EPR 4 - -#define GGML_RWKV_F32x4 v128_t -#define GGML_RWKV_F32x4_ZERO wasm_f32x4_splat(0.0f) -#define GGML_RWKV_F32x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_RWKV_F32x4_LOAD wasm_v128_load -#define GGML_RWKV_F32x4_STORE wasm_v128_store -#define GGML_RWKV_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) -#define GGML_RWKV_F32x4_ADD wasm_f32x4_add -#define GGML_RWKV_F32x4_MUL wasm_f32x4_mul -#define GGML_RWKV_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_RWKV_F32_ARR/2; ++i) { \ - x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/4; ++i) { \ - x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/8; ++i) { \ - x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_RWKV_F32_VEC GGML_RWKV_F32x4 -#define GGML_RWKV_F32_VEC_ZERO GGML_RWKV_F32x4_ZERO -#define GGML_RWKV_F32_VEC_SET1 GGML_RWKV_F32x4_SET1 -#define GGML_RWKV_F32_VEC_LOAD GGML_RWKV_F32x4_LOAD -#define GGML_RWKV_F32_VEC_STORE GGML_RWKV_F32x4_STORE -#define GGML_RWKV_F32_VEC_FMA GGML_RWKV_F32x4_FMA -#define GGML_RWKV_F32_VEC_ADD GGML_RWKV_F32x4_ADD -#define GGML_RWKV_F32_VEC_MUL GGML_RWKV_F32x4_MUL -#define GGML_RWKV_F32_VEC_REDUCE GGML_RWKV_F32x4_REDUCE - -// F16 WASM - -#define GGML_RWKV_F16_STEP 16 -#define GGML_RWKV_F16_EPR 4 - -inline static v128_t __wasm_f16x4_load(const ggml_rwkv_fp16_t * p) { - float tmp[4]; - - tmp[0] = GGML_RWKV_FP16_TO_FP32(p[0]); - tmp[1] = GGML_RWKV_FP16_TO_FP32(p[1]); - tmp[2] = GGML_RWKV_FP16_TO_FP32(p[2]); - tmp[3] = GGML_RWKV_FP16_TO_FP32(p[3]); - - return wasm_v128_load(tmp); -} - -inline static void __wasm_f16x4_store(ggml_rwkv_fp16_t * p, v128_t x) { - float tmp[4]; - - wasm_v128_store(tmp, x); - - p[0] = GGML_RWKV_FP32_TO_FP16(tmp[0]); - p[1] = GGML_RWKV_FP32_TO_FP16(tmp[1]); - p[2] = GGML_RWKV_FP32_TO_FP16(tmp[2]); - p[3] = GGML_RWKV_FP32_TO_FP16(tmp[3]); -} - -#define GGML_RWKV_F16x4 v128_t -#define GGML_RWKV_F16x4_ZERO wasm_f32x4_splat(0.0f) -#define 
GGML_RWKV_F16x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_RWKV_F16x4_LOAD(x) __wasm_f16x4_load(x) -#define GGML_RWKV_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) -#define GGML_RWKV_F16x4_FMA GGML_RWKV_F32x4_FMA -#define GGML_RWKV_F16x4_ADD wasm_f32x4_add -#define GGML_RWKV_F16x4_MUL wasm_f32x4_mul -#define GGML_RWKV_F16x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_RWKV_F16_ARR/2; ++i) { \ - x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_RWKV_F16_ARR/4; ++i) { \ - x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_RWKV_F16_ARR/8; ++i) { \ - x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_RWKV_F16_VEC GGML_RWKV_F16x4 -#define GGML_RWKV_F16_VEC_ZERO GGML_RWKV_F16x4_ZERO -#define GGML_RWKV_F16_VEC_SET1 GGML_RWKV_F16x4_SET1 -#define GGML_RWKV_F16_VEC_LOAD(p, i) GGML_RWKV_F16x4_LOAD(p) -#define GGML_RWKV_F16_VEC_STORE(p, r, i) GGML_RWKV_F16x4_STORE(p, r[i]) -#define GGML_RWKV_F16_VEC_FMA GGML_RWKV_F16x4_FMA -#define GGML_RWKV_F16_VEC_ADD GGML_RWKV_F16x4_ADD -#define GGML_RWKV_F16_VEC_MUL GGML_RWKV_F16x4_MUL -#define GGML_RWKV_F16_VEC_REDUCE GGML_RWKV_F16x4_REDUCE - -#elif defined(__SSE3__) - -#define GGML_RWKV_SIMD - -// F32 SSE - -#define GGML_RWKV_F32_STEP 32 -#define GGML_RWKV_F32_EPR 4 - -#define GGML_RWKV_F32x4 __m128 -#define GGML_RWKV_F32x4_ZERO _mm_setzero_ps() -#define GGML_RWKV_F32x4_SET1(x) _mm_set1_ps(x) -#define GGML_RWKV_F32x4_LOAD _mm_loadu_ps -#define GGML_RWKV_F32x4_STORE _mm_storeu_ps -#if defined(__FMA__) - // TODO: Does this work? - #define GGML_RWKV_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) -#else - #define GGML_RWKV_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) -#endif -#define GGML_RWKV_F32x4_ADD _mm_add_ps -#define GGML_RWKV_F32x4_MUL _mm_mul_ps -#define GGML_RWKV_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_RWKV_F32_ARR/2; ++i) { \ - x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/4; ++i) { \ - x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_RWKV_F32_ARR/8; ++i) { \ - x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \ - } \ - const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ - res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ -} -// TODO: is this optimal ? 
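For reference, a minimal scalar sketch (illustrative only, not present in the removed file) of what the SSE GGML_RWKV_F32x4_REDUCE macro above computes: it pairwise-adds the GGML_RWKV_F32_ARR (= 32/4 = 8) partial-sum registers into x[0] with a tree reduction, then horizontally adds the four lanes with two _mm_hadd_ps, so the result is the sum of every lane of every accumulator (up to floating-point rounding order). The function name and array shape below are assumptions chosen for illustration.

/* Illustrative sketch, not part of the patch: scalar equivalent of the SSE
 * GGML_RWKV_F32x4_REDUCE path, assuming 8 accumulators (GGML_RWKV_F32_ARR)
 * of 4 lanes each (GGML_RWKV_F32_EPR). */
static inline float f32x4_reduce_scalar_sketch(const float acc[8][4]) {
    float sum = 0.0f;
    for (int i = 0; i < 8; ++i) {        /* one accumulator register per step */
        for (int l = 0; l < 4; ++l) {    /* one float lane per step */
            sum += acc[i][l];
        }
    }
    return sum;
}

The vectorized macro performs the same summation as a pairwise tree reduction, which keeps the vector adds independent enough to pipeline before the final horizontal adds; that sequencing is presumably what the "is this optimal?" comment is questioning.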
- -#define GGML_RWKV_F32_VEC GGML_RWKV_F32x4 -#define GGML_RWKV_F32_VEC_ZERO GGML_RWKV_F32x4_ZERO -#define GGML_RWKV_F32_VEC_SET1 GGML_RWKV_F32x4_SET1 -#define GGML_RWKV_F32_VEC_LOAD GGML_RWKV_F32x4_LOAD -#define GGML_RWKV_F32_VEC_STORE GGML_RWKV_F32x4_STORE -#define GGML_RWKV_F32_VEC_FMA GGML_RWKV_F32x4_FMA -#define GGML_RWKV_F32_VEC_ADD GGML_RWKV_F32x4_ADD -#define GGML_RWKV_F32_VEC_MUL GGML_RWKV_F32x4_MUL -#define GGML_RWKV_F32_VEC_REDUCE GGML_RWKV_F32x4_REDUCE - -// F16 SSE - -#define GGML_RWKV_F16_STEP 32 -#define GGML_RWKV_F16_EPR 4 - -static inline __m128 __sse_f16x4_load(ggml_rwkv_fp16_t *x) { - float tmp[4]; - - tmp[0] = GGML_RWKV_FP16_TO_FP32(x[0]); - tmp[1] = GGML_RWKV_FP16_TO_FP32(x[1]); - tmp[2] = GGML_RWKV_FP16_TO_FP32(x[2]); - tmp[3] = GGML_RWKV_FP16_TO_FP32(x[3]); - - return _mm_loadu_ps(tmp); -} - -static inline void __sse_f16x4_store(ggml_rwkv_fp16_t *x, __m128 y) { - float arr[4]; - - _mm_storeu_ps(arr, y); - - x[0] = GGML_RWKV_FP32_TO_FP16(arr[0]); - x[1] = GGML_RWKV_FP32_TO_FP16(arr[1]); - x[2] = GGML_RWKV_FP32_TO_FP16(arr[2]); - x[3] = GGML_RWKV_FP32_TO_FP16(arr[3]); -} - -#define GGML_RWKV_F32Cx4 __m128 -#define GGML_RWKV_F32Cx4_ZERO _mm_setzero_ps() -#define GGML_RWKV_F32Cx4_SET1(x) _mm_set1_ps(x) -#define GGML_RWKV_F32Cx4_LOAD(x) __sse_f16x4_load(x) -#define GGML_RWKV_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) -#define GGML_RWKV_F32Cx4_FMA GGML_RWKV_F32x4_FMA -#define GGML_RWKV_F32Cx4_ADD _mm_add_ps -#define GGML_RWKV_F32Cx4_MUL _mm_mul_ps -#define GGML_RWKV_F32Cx4_REDUCE GGML_RWKV_F32x4_REDUCE - -#define GGML_RWKV_F16_VEC GGML_RWKV_F32Cx4 -#define GGML_RWKV_F16_VEC_ZERO GGML_RWKV_F32Cx4_ZERO -#define GGML_RWKV_F16_VEC_SET1 GGML_RWKV_F32Cx4_SET1 -#define GGML_RWKV_F16_VEC_LOAD(p, i) GGML_RWKV_F32Cx4_LOAD(p) -#define GGML_RWKV_F16_VEC_STORE(p, r, i) GGML_RWKV_F32Cx4_STORE(p, r[i]) -#define GGML_RWKV_F16_VEC_FMA GGML_RWKV_F32Cx4_FMA -#define GGML_RWKV_F16_VEC_ADD GGML_RWKV_F32Cx4_ADD -#define GGML_RWKV_F16_VEC_MUL GGML_RWKV_F32Cx4_MUL -#define GGML_RWKV_F16_VEC_REDUCE GGML_RWKV_F32Cx4_REDUCE - -#endif - -// GGML_RWKV_F32_ARR / GGML_RWKV_F16_ARR -// number of registers to use per step -#ifdef GGML_RWKV_SIMD -#define GGML_RWKV_F32_ARR (GGML_RWKV_F32_STEP/GGML_RWKV_F32_EPR) -#define GGML_RWKV_F16_ARR (GGML_RWKV_F16_STEP/GGML_RWKV_F16_EPR) -#endif - -// -// fundamental operations -// - -inline static void ggml_rwkv_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_rwkv_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_rwkv_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_rwkv_vec_set_f16(const int n, ggml_rwkv_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_rwkv_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void ggml_rwkv_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void ggml_rwkv_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void ggml_rwkv_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void ggml_rwkv_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } 
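Several of the element-wise helpers in this group (notably ggml_rwkv_vec_1_minus_x_f32 and ggml_rwkv_vec_element_wise_max_f32 just below, and the sigmoid further down) appear to be RWKV-specific additions rather than stock ggml primitives. As a hedged illustration of why a dedicated 1-x operation is worth having: RWKV-style token shifting typically blends the current and previous activations as x*mix + x_prev*(1-mix). A plain scalar reference of that identity, assuming nothing from this file (the function and parameter names below are made up for illustration):

/* out[i] = x[i]*mix[i] + x_prev[i]*(1 - mix[i])  --  scalar sketch only */
static void token_shift_mix_ref(int n, float * out,
                                const float * x, const float * x_prev, const float * mix) {
    for (int i = 0; i < n; ++i) {
        out[i] = x[i]*mix[i] + x_prev[i]*(1.0f - mix[i]);
    }
}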
-inline static void ggml_rwkv_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void ggml_rwkv_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void ggml_rwkv_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } -inline static void ggml_rwkv_vec_1_minus_x_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1 - x[i]; } -inline static void ggml_rwkv_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void ggml_rwkv_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } -inline static void ggml_rwkv_vec_element_wise_max_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = fmaxf(x[i], y[i]); } - -inline static void ggml_rwkv_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { -#ifdef GGML_RWKV_SIMD - float sumf = 0.0f; - const int np = (n & ~(GGML_RWKV_F32_STEP - 1)); - - GGML_RWKV_F32_VEC sum[GGML_RWKV_F32_ARR] = { GGML_RWKV_F32_VEC_ZERO }; - - GGML_RWKV_F32_VEC ax[GGML_RWKV_F32_ARR]; - GGML_RWKV_F32_VEC ay[GGML_RWKV_F32_ARR]; - - for (int i = 0; i < np; i += GGML_RWKV_F32_STEP) { - for (int j = 0; j < GGML_RWKV_F32_ARR; j++) { - ax[j] = GGML_RWKV_F32_VEC_LOAD(x + i + j*GGML_RWKV_F32_EPR); - ay[j] = GGML_RWKV_F32_VEC_LOAD(y + i + j*GGML_RWKV_F32_EPR); - - sum[j] = GGML_RWKV_F32_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_RWKV_F32_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } -#else - // scalar - ggml_rwkv_float sumf = 0.0; - for (int i = 0; i < n; ++i) { - sumf += (ggml_rwkv_float)(x[i]*y[i]); - } -#endif - - *s = sumf; -} - -#if __AVX512F__ && QK == 32 -static inline __m512 dot_q4_0_oneblock_avx512( - __m512 acc, - const block_q4_0 * restrict x, - const block_q4_0 * restrict y, - int i -) { - // Compute combined scale for the block - __m512 d = _mm512_set1_ps( x[i].d * y[i].d ); - - __m256i bx = bytesFromNibbles( x[i].qs ); - __m256i by = bytesFromNibbles( y[i].qs ); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. 
- const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); - by = _mm256_sub_epi8( by, off ); - - // Sign-extend 16 signed bytes into int16_t - __m512i x32 = _mm512_cvtepi8_epi16( bx ); - __m512i y32 = _mm512_cvtepi8_epi16( by ); - // Compute products of int16_t integers, add pairwise - __m512i i64 = _mm512_madd_epi16( x32, y32 ); - - // Convert int32_t to float - __m512 p = _mm512_cvtepi32_ps( i64 ); - // Apply the scale, and accumulate - return _mm512_fmadd_ps( d, p, acc ); -} -#endif - -inline static void ggml_rwkv_vec_dot_f16(const int n, float * restrict s, ggml_rwkv_fp16_t * restrict x, ggml_rwkv_fp16_t * restrict y) { - ggml_rwkv_float sumf = 0.0; - -#if defined(GGML_RWKV_SIMD) - const int np = (n & ~(GGML_RWKV_F16_STEP - 1)); - - GGML_RWKV_F16_VEC sum[GGML_RWKV_F16_ARR] = { GGML_RWKV_F16_VEC_ZERO }; - - GGML_RWKV_F16_VEC ax[GGML_RWKV_F16_ARR]; - GGML_RWKV_F16_VEC ay[GGML_RWKV_F16_ARR]; - - for (int i = 0; i < np; i += GGML_RWKV_F16_STEP) { - for (int j = 0; j < GGML_RWKV_F16_ARR; j++) { - ax[j] = GGML_RWKV_F16_VEC_LOAD(x + i + j*GGML_RWKV_F16_EPR, j); - ay[j] = GGML_RWKV_F16_VEC_LOAD(y + i + j*GGML_RWKV_F16_EPR, j); - - sum[j] = GGML_RWKV_F16_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_RWKV_F16_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += (ggml_rwkv_float)(GGML_RWKV_FP16_TO_FP32(x[i])*GGML_RWKV_FP16_TO_FP32(y[i])); - } -#else - for (int i = 0; i < n; ++i) { - sumf += (ggml_rwkv_float)(GGML_RWKV_FP16_TO_FP32(x[i])*GGML_RWKV_FP16_TO_FP32(y[i])); - } -#endif - - *s = sumf; -} - -static void ggml_rwkv_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK; - - assert(n % QK == 0); - assert(nb % 2 == 0); - - const block_q4_0 * restrict x = vx; - const block_q4_0 * restrict y = vy; - - ggml_rwkv_float sumf = 0.0; - -#if defined(__ARM_NEON) - float sum0 = 0.0f; - float sum1 = 0.0f; - - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict y0 = &y[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q4_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0xf); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v1_0 = vld1q_u8(y0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - const uint8x16_t v1_1 = vld1q_u8(y1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); - const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b)); - - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4)); - - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); - const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b)); - - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); - - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); - - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); - - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int16x8_t - int32x4_t p_0 = 
vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls); - int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls); - - p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs); - p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs); - - // scalar -#if defined(__ARM_FEATURE_QRDMX) - sum0 += x0->d * y0->d * vaddvq_s32(p_0); - sum1 += x1->d * y1->d * vaddvq_s32(p_1); -#else - sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3)); - sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3)); -#endif -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); - - const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); - const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); - - const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); - const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); - - const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); - const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); - - // scalar -#if defined(__ARM_FEATURE_QRDMX) - sum0 += x0->d * y0->d * vaddvq_s16(p_0); - sum1 += x1->d * y1->d * vaddvq_s16(p_1); -#else - sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); - sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); -#endif -#endif - } - - sumf = (ggml_rwkv_float)(sum0 + sum1); -#elif defined(__AVX512F__) - // Initialize accumulator with zeros - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - - const int superblock_size = 8; - const int superblock_count = nb / superblock_size; - const int remainder = nb % superblock_size; - - for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) { - int i = superblock_ix * superblock_size; - - acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 ); - acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 ); - acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 ); - acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 ); - acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 ); - acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 ); - acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 ); - acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 ); - } - - // Remainders - for (int i = superblock_count * superblock_size; i < nb; ++i) { - acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i ); - } - - // Horizontal sum of all lanes of the accumulator - sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 ); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), 
_mm256_broadcast_ss( &y[i].d ) ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - __m256i bx = bytesFromNibbles( x[i].qs ); - __m256i by = bytesFromNibbles( y[i].qs ); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); - by = _mm256_sub_epi8( by, off ); - - // Sign-extend first 16 signed bytes into int16_t - __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); - __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); - // Compute products of int16_t integers, add pairwise - __m256i i32 = _mm256_madd_epi16( x16, y16 ); - - // Sign-extend last 16 signed bytes into int16_t vectors - x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); - y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); - // Accumulate products of int16_t integers - i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) ); - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps( i32 ); - // Apply the scale, and accumulate - acc = _mm256_fmadd_ps( d, p, acc ); - } - - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ); -#elif defined(__wasm_simd128__) - // wasm simd - float sum0 = 0.0f; - float sum1 = 0.0f; - - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &px[i + 0]; - const block_q4_0 * restrict y0 = &py[i + 0]; - const block_q4_0 * restrict x1 = &px[i + 1]; - const block_q4_0 * restrict y1 = &py[i + 1]; - - const v128_t m4b = wasm_u8x16_splat(0xf); - const v128_t s8b = wasm_i8x16_splat(0x8); - - const v128_t v0_0 = wasm_v128_load(x0.qs); - const v128_t v0_1 = wasm_v128_load(y0.qs); - const v128_t v1_0 = wasm_v128_load(x1.qs); - const v128_t v1_1 = wasm_v128_load(y1.qs); - - // 4-bit -> 8-bit - const v128_t v0_0l = wasm_v128_and(v0_0, m4b); - const v128_t v1_0l = wasm_v128_and(v1_0, m4b); - - const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); - const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4); - - const v128_t v0_1l = wasm_v128_and(v0_1, m4b); - const v128_t v1_1l = wasm_v128_and(v1_1, m4b); - - const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); - const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4); - - // sub 8 - const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); - const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b); - - const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); - const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b); - - const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); - const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b); - - const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); - const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b); - - // dot product into int16x8_t - const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls)); - const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls)); - - const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs)); - const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs)); - - const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls)); - const v128_t pl1h = 
wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls)); - - const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs)); - const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs)); - - const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h); - const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h); - - const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h); - const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h); - - const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0); - const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1); - - sum0 += x0->d * y0->d * ( - wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) + - wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) + - wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) + - wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7)); - sum1 += x1->d * y1->d * ( - wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) + - wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) + - wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) + - wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7)); - } - - sumf = sum0 + sum1; -#else - // scalar - for (int i = 0; i < nb; i++) { - const float d0 = x[i].d; - const float d1 = y[i].d; - - const uint8_t * restrict p0 = x[i].qs; - const uint8_t * restrict p1 = y[i].qs; - - for (int j = 0; j < QK/2; j++) { - const uint8_t v0 = p0[j]; - const uint8_t v1 = p1[j]; - - const float f0 = d0*((int8_t) (v0 & 0xf) - 8); - const float f1 = d0*((int8_t) (v0 >> 4) - 8); - - const float f2 = d1*((int8_t) (v1 & 0xf) - 8); - const float f3 = d1*((int8_t) (v1 >> 4) - 8); - - sumf += f0*f2 + f1*f3; - } - } -#endif - - *s = sumf; -} - -static void ggml_rwkv_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK; - - const block_q4_1 * restrict x = vx; - const block_q4_1 * restrict y = vy; - - float sumf = 0.0; - -#if defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - // Accumulator for constant offsets - float acc_offset = 0.0f; - - // Main loop - for (int i = 0; i < nb; ++i) { - const float * d0 = &x[i].d; - const float * d1 = &y[i].d; - - const float * m0 = &x[i].m; - const float * m1 = &y[i].m; - - const __m256 d0v = _mm256_broadcast_ss( d0 ); - const __m256 d1v = _mm256_broadcast_ss( d1 ); - const __m256 m0v = _mm256_broadcast_ss( m0 ); - const __m256 m1v = _mm256_broadcast_ss( m1 ); - - // Compute combined scale for the block - const __m256 scale_01 = _mm256_mul_ps( d0v, d1v ); - - // Compute cross scales for the block - const __m256 scale_0 = _mm256_mul_ps( d0v, m1v ); - const __m256 scale_1 = _mm256_mul_ps( m0v, d1v ); - const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - __m256i bx = bytesFromNibbles( x[i].qs ); - __m256i by = bytesFromNibbles( y[i].qs ); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. 
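// Why three accumulators in this q4_1 kernel: each block element decodes as
// x ~ d0*q0 + m0 and y ~ d1*q1 + m1 with q0, q1 in [0, 15], so the element-wise
// product expands to
//     d0*d1*q0*q1 + d0*m1*q0 + d1*m0*q1 + m0*m1.
// scale_01 weights the madd of q0*q1 computed below, cross_scales weights the
// per-block byte sums of q0 and q1 (the _mm256_sad_epu8 path), and acc_offset
// collects the constant m0*m1 term, applied once per element via acc_offset * QK
// at the very end. Unlike q4_0 there is no "-8" to subtract here: the nibbles
// stay unsigned because the block minimum m already absorbs the offset.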
- - // Sign-extend first 16 signed bytes into int16_t - __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); - __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); - // Compute products of int16_t integers, add pairwise - __m256i i32 = _mm256_madd_epi16( x16, y16 ); - - // Sign-extend last 16 signed bytes into int16_t vectors - __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); - __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); - // Accumulate products of int16_t integers - i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) ); - - // compute sums of unsigned bytes in bx, by in blocks of 8. - // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000, - // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400. - // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ] - __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() ); - __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() ); - __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); - __m256 sums = _mm256_cvtepi32_ps( sumsi ); - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps( i32 ); - // Apply the scale, and accumulate - // acc += d0*d1*x*y + d0*m1*x + d1*m0*y - acc = _mm256_fmadd_ps( scale_01, p, acc ); - acc = _mm256_fmadd_ps( cross_scales, sums, acc ); - // acc_offset += m0*m1 (for each entry in the block) - acc_offset += (*m0)*(*m1); - } - - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ) + acc_offset * QK; -#elif defined(__ARM_NEON) - float sum00 = 0.0f; - float sum01 = 0.0f; - float sum10 = 0.0f; - float sum11 = 0.0f; - - for (int i = 0; i < nb; ++i) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict y0 = &y[i + 0]; - - const uint8x16_t m4b = vdupq_n_u8(0xf); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v1_0 = vld1q_u8(y0->qs); - - // and with 0xf - const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); - const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); - - const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); - const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); - - // dot product into uint16x8_t - const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); - const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); - - const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); - const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); - - const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); - const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); - - sum00 += x0->m*y0->m; - sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); - sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); - sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0)); - } - - sumf = QK*sum00 + sum01 + sum10 + sum11; -#else - // scalar - for (int i = 0; i < nb; i++) { - const float d0 = x[i].d; - const float d1 = y[i].d; - - const float m0 = x[i].m; - const float m1 = y[i].m; - - const uint8_t * restrict p0 = x[i].qs; - const uint8_t * restrict p1 = y[i].qs; - - for (int j = 0; j < QK/2; j++) { - const uint8_t v0 = p0[j]; - const uint8_t v1 = p1[j]; - - const float f0 = d0*(v0 & 0xf) + m0; - const float f1 = d0*(v0 >> 
4) + m0; - - const float f2 = d1*(v1 & 0xf) + m1; - const float f3 = d1*(v1 >> 4) + m1; - - sumf += f0*f2 + f1*f3; - } - } -#endif - - *s = sumf; -} - -// compute GGML_RWKV_VEC_DOT_UNROLL dot products at once -// xs - x row stride in bytes -inline static void ggml_rwkv_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_rwkv_fp16_t * restrict y) { - ggml_rwkv_float sumf[GGML_RWKV_VEC_DOT_UNROLL] = { 0.0 }; - - ggml_rwkv_fp16_t * restrict x[GGML_RWKV_VEC_DOT_UNROLL]; - - for (int i = 0; i < GGML_RWKV_VEC_DOT_UNROLL; ++i) { - x[i] = (ggml_rwkv_fp16_t *) ((char *) xv + i*xs); - } - -#if defined(GGML_RWKV_SIMD) - const int np = (n & ~(GGML_RWKV_F16_STEP - 1)); - - GGML_RWKV_F16_VEC sum[GGML_RWKV_VEC_DOT_UNROLL][GGML_RWKV_F16_ARR] = { { GGML_RWKV_F16_VEC_ZERO } }; - - GGML_RWKV_F16_VEC ax[GGML_RWKV_F16_ARR]; - GGML_RWKV_F16_VEC ay[GGML_RWKV_F16_ARR]; - - for (int i = 0; i < np; i += GGML_RWKV_F16_STEP) { - for (int j = 0; j < GGML_RWKV_F16_ARR; j++) { - ay[j] = GGML_RWKV_F16_VEC_LOAD(y + i + j*GGML_RWKV_F16_EPR, j); - - for (int k = 0; k < GGML_RWKV_VEC_DOT_UNROLL; ++k) { - ax[j] = GGML_RWKV_F16_VEC_LOAD(x[k] + i + j*GGML_RWKV_F16_EPR, j); - - sum[k][j] = GGML_RWKV_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); - } - } - } - - // reduce sum0..sum3 to sum0 - for (int k = 0; k < GGML_RWKV_VEC_DOT_UNROLL; ++k) { - GGML_RWKV_F16_VEC_REDUCE(sumf[k], sum[k]); - } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_RWKV_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_rwkv_float)(GGML_RWKV_FP16_TO_FP32(x[j][i])*GGML_RWKV_FP16_TO_FP32(y[i])); - } - } -#else - for (int i = 0; i < n; ++i) { - for (int j = 0; j < GGML_RWKV_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_rwkv_float)(GGML_RWKV_FP16_TO_FP32(x[j][i])*GGML_RWKV_FP16_TO_FP32(y[i])); - } - } -#endif - - for (int i = 0; i < GGML_RWKV_VEC_DOT_UNROLL; ++i) { - s[i] = sumf[i]; - } -} - -inline static void ggml_rwkv_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#if defined(GGML_RWKV_SIMD) - const int np = (n & ~(GGML_RWKV_F32_STEP - 1)); - - GGML_RWKV_F32_VEC vx = GGML_RWKV_F32_VEC_SET1(v); - - GGML_RWKV_F32_VEC ax[GGML_RWKV_F32_ARR]; - GGML_RWKV_F32_VEC ay[GGML_RWKV_F32_ARR]; - - for (int i = 0; i < np; i += GGML_RWKV_F32_STEP) { - for (int j = 0; j < GGML_RWKV_F32_ARR; j++) { - ax[j] = GGML_RWKV_F32_VEC_LOAD(x + i + j*GGML_RWKV_F32_EPR); - ay[j] = GGML_RWKV_F32_VEC_LOAD(y + i + j*GGML_RWKV_F32_EPR); - ay[j] = GGML_RWKV_F32_VEC_FMA(ay[j], ax[j], vx); - - GGML_RWKV_F32_VEC_STORE(y + i + j*GGML_RWKV_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] += x[i]*v; - } -#endif -} - -//inline static void ggml_rwkv_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } -inline static void ggml_rwkv_vec_scale_f32(const int n, float * y, const float v) { -#if defined(GGML_RWKV_SIMD) - const int np = (n & ~(GGML_RWKV_F32_STEP - 1)); - - GGML_RWKV_F32_VEC vx = GGML_RWKV_F32_VEC_SET1(v); - - GGML_RWKV_F32_VEC ay[GGML_RWKV_F32_ARR]; - - for (int i = 0; i < np; i += GGML_RWKV_F32_STEP) { - for (int j = 0; j < GGML_RWKV_F32_ARR; j++) { - ay[j] = GGML_RWKV_F32_VEC_LOAD(y + i + j*GGML_RWKV_F32_EPR); - ay[j] = GGML_RWKV_F32_VEC_MUL(ay[j], vx); - - GGML_RWKV_F32_VEC_STORE(y + i + j*GGML_RWKV_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] *= v; 
- } -#endif -} - -inline static void ggml_rwkv_vec_norm_f32 (const int n, float * s, const float * x) { ggml_rwkv_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } -inline static void ggml_rwkv_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } -inline static void ggml_rwkv_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } -inline static void ggml_rwkv_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } -inline static void ggml_rwkv_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } -inline static void ggml_rwkv_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } -inline static void ggml_rwkv_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } - -static const float GELU_COEF_A = 0.044715f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -inline static float ggml_rwkv_gelu_f32(float x) { - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -inline static void ggml_rwkv_vec_gelu_f16(const int n, ggml_rwkv_fp16_t * y, const ggml_rwkv_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = table_gelu_f16[i16[i]]; - } -} - -#ifdef GGML_RWKV_GELU_FP16 -inline static void ggml_rwkv_vec_gelu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - ggml_rwkv_fp16_t fp16 = GGML_RWKV_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_RWKV_FP16_TO_FP32(table_gelu_f16[t]); - } -} -#else -inline static void ggml_rwkv_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_rwkv_gelu_f32(x[i]); - } -} -#endif - -// Sigmoid function -inline static float ggml_rwkv_sigmoid_f32(float x) { - return 1.0F / (1.0F + expf(-x)); -} - -inline static void ggml_rwkv_vec_sigmoid_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_rwkv_sigmoid_f32(x[i]); - } -} - -// Sigmoid Linear Unit (SiLU) function -inline static float ggml_rwkv_silu_f32(float x) { - return x/(1.0f + expf(-x)); -} - -inline static void ggml_rwkv_vec_silu_f16(const int n, ggml_rwkv_fp16_t * y, const ggml_rwkv_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = table_silu_f16[i16[i]]; - } -} - -#ifdef GGML_RWKV_SILU_FP16 -inline static void ggml_rwkv_vec_silu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - ggml_rwkv_fp16_t fp16 = GGML_RWKV_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_RWKV_FP16_TO_FP32(table_silu_f16[t]); - } -} -#else -inline static void ggml_rwkv_vec_silu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_rwkv_silu_f32(x[i]); - } -} -#endif - -inline static void ggml_rwkv_vec_sum_f32(const int n, float * s, const float * x) { -#ifndef GGML_RWKV_USE_ACCELERATE - ggml_rwkv_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += (ggml_rwkv_float)x[i]; - } - *s = sum; -#else - vDSP_sve(x, 1, s, n); -#endif -} - -inline static void ggml_rwkv_vec_max_f32(const int n, float * s, const float * x) { -#ifndef GGML_RWKV_USE_ACCELERATE - float max = 
-INFINITY; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - } - *s = max; -#else - vDSP_maxv(x, 1, s, n); -#endif -} - -inline static void ggml_rwkv_vec_norm_inv_f32(const int n, float * s, const float * x) { - ggml_rwkv_vec_norm_f32(n, s, x); - *s = 1.f/(*s); -} - -// -// logging -// - -#if (GGML_RWKV_DEBUG >= 1) -#define GGML_RWKV_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_RWKV_PRINT_DEBUG(...) -#endif - -#if (GGML_RWKV_DEBUG >= 5) -#define GGML_RWKV_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_RWKV_PRINT_DEBUG_5(...) -#endif - -#if (GGML_RWKV_DEBUG >= 10) -#define GGML_RWKV_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_RWKV_PRINT_DEBUG_10(...) -#endif - -#define GGML_RWKV_PRINT(...) printf(__VA_ARGS__) - -// -// data types -// - -static const int GGML_RWKV_BLCK_SIZE[GGML_RWKV_TYPE_COUNT] = { - QK, - QK, - QK, - 1, - 1, - 1, - 1, - 1, -}; - -static_assert(GGML_RWKV_TYPE_COUNT == 8, "GGML_RWKV_TYPE_COUNT != 8"); - -static const size_t GGML_RWKV_TYPE_SIZE[GGML_RWKV_TYPE_COUNT] = { - sizeof(block_q4_0), - sizeof(block_q4_1), - sizeof(block_q4_1_o), - sizeof(int8_t ), - sizeof(int16_t), - sizeof(int32_t), - sizeof(ggml_rwkv_fp16_t), - sizeof(float ), -}; - -// don't forget to update the array above when adding new types -static_assert(GGML_RWKV_TYPE_COUNT == 8, "GGML_RWKV_TYPE_COUNT != 8"); - -static const char * GGML_RWKV_OP_LABEL[GGML_RWKV_OP_COUNT] = { - "NONE", - - "DUP", - "ADD", - "SUB", - "MUL", - "DIV", - "SQR", - "SQRT", - "SUM", - "MEAN", - "REPEAT", - "ABS", - "SGN", - "NEG", - "EXP", - "1_MINUS_X", - "MAX", - - "STEP", - "RELU", - "GELU", - "SIGMOID", - "SILU", - "NORM", - "RMS_NORM", - - "MUL_MAT", - - "SCALE", - "CPY", - "RESHAPE", - "VIEW", - "PERMUTE", - "TRANSPOSE", - "GET_ROWS", - "DIAG_MASK_INF", - "SOFT_MAX", - "ROPE", - "CONV_1D_1S", - "CONV_1D_2S", - - "FLASH_ATTN", - "FLASH_FF", -}; - -static_assert(GGML_RWKV_OP_COUNT == 39, "GGML_RWKV_OP_COUNT != 39"); - -static const char * GGML_RWKV_OP_SYMBOL[GGML_RWKV_OP_COUNT] = { - "none", - - "x", - "x+y", - "x-y", - "x*y", - "x/y", - "x^2", - "√x", - "Σx", - "Σx/n", - "repeat(x)", - "abs(x)", - "sgn(x)", - "-x", - "e^x", - "1-x", - "max(x,y)", - - "step(x)", - "relu(x)", - "gelu(x)", - "sigmoid(x)", - "silu(x)", - "norm(x)", - "rms_norm(x)", - - "X*Y", - - "x*v", - "x-\\>y", - "reshape(x)", - "view(x)", - "permute(x)", - "transpose(x)", - "get_rows(x)", - "diag_mask_inf(x)", - "soft_max(x)", - "rope(x)", - "conv_1d_1s(x)", - "conv_1d_2s(x)", - - "flash_attn(x)", - "flash_ff(x)", -}; - -static_assert(GGML_RWKV_OP_COUNT == 39, "GGML_RWKV_OP_COUNT != 39"); - -// -// ggml object -// - -struct ggml_rwkv_object { - size_t offs; - size_t size; - - struct ggml_rwkv_object * next; - - char padding[8]; -}; - -static const size_t GGML_RWKV_OBJECT_SIZE = sizeof(struct ggml_rwkv_object); - -static_assert(sizeof(struct ggml_rwkv_object)%GGML_RWKV_MEM_ALIGN == 0, "ggml_rwkv_object size must be a multiple of GGML_RWKV_MEM_ALIGN"); -static_assert(sizeof(struct ggml_rwkv_tensor)%GGML_RWKV_MEM_ALIGN == 0, "ggml_rwkv_tensor size must be a multiple of GGML_RWKV_MEM_ALIGN"); - -// -// ggml context -// - -struct ggml_rwkv_context { - size_t mem_size; - void * mem_buffer; - bool mem_buffer_owned; - bool mem_buffer_mlocked; - - int n_objects; - - struct ggml_rwkv_object * objects_begin; - struct ggml_rwkv_object * objects_end; - - struct ggml_rwkv_scratch scratch; - struct ggml_rwkv_scratch scratch_save; -}; - -struct ggml_rwkv_context_container { - bool used; - - struct ggml_rwkv_context context; 
-}; - -// -// compute types -// - -enum ggml_rwkv_task_type { - GGML_RWKV_TASK_INIT = 0, - GGML_RWKV_TASK_COMPUTE, - GGML_RWKV_TASK_FINALIZE, -}; - -struct ggml_rwkv_compute_params { - enum ggml_rwkv_task_type type; - - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; -}; - -// -// ggml state -// - -struct ggml_rwkv_state { - struct ggml_rwkv_context_container contexts[GGML_RWKV_MAX_CONTEXTS]; -}; - -// global state -static struct ggml_rwkv_state g_state; -static atomic_int g_state_barrier = 0; - -// barrier via spin lock -inline static void ggml_rwkv_critical_section_start(void) { - int processing = atomic_fetch_add(&g_state_barrier, 1); - - while (processing > 0) { - // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); - sched_yield(); // TODO: reconsider this - processing = atomic_fetch_add(&g_state_barrier, 1); - } -} - -// TODO: make this somehow automatically executed -// some sort of "sentry" mechanism -inline static void ggml_rwkv_critical_section_end(void) { - atomic_fetch_sub(&g_state_barrier, 1); -} - -//////////////////////////////////////////////////////////////////////////////// - -void ggml_rwkv_print_object(const struct ggml_rwkv_object * obj) { - GGML_RWKV_PRINT(" - ggml_rwkv_object: offset = %zu, size = %zu, next = %p\n", - obj->offs, obj->size, (const void *) obj->next); -} - -void ggml_rwkv_print_objects(const struct ggml_rwkv_context * ctx) { - struct ggml_rwkv_object * obj = ctx->objects_begin; - - GGML_RWKV_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); - - while (obj != NULL) { - ggml_rwkv_print_object(obj); - obj = obj->next; - } - - GGML_RWKV_PRINT("%s: --- end ---\n", __func__); -} - -int ggml_rwkv_nelements(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} - -int ggml_rwkv_nrows(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} - -size_t ggml_rwkv_nbytes(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return (ggml_rwkv_nelements(tensor)*GGML_RWKV_TYPE_SIZE[tensor->type])/GGML_RWKV_BLCK_SIZE[tensor->type]; -} - -int ggml_rwkv_blck_size(enum ggml_rwkv_type type) { - return GGML_RWKV_BLCK_SIZE[type]; -} - -size_t ggml_rwkv_type_size(enum ggml_rwkv_type type) { - return GGML_RWKV_TYPE_SIZE[type]; -} - -float ggml_rwkv_type_sizef(enum ggml_rwkv_type type) { - return ((float)(GGML_RWKV_TYPE_SIZE[type]))/GGML_RWKV_BLCK_SIZE[type]; -} - -size_t ggml_rwkv_element_size(const struct ggml_rwkv_tensor * tensor) { - return GGML_RWKV_TYPE_SIZE[tensor->type]; -} - -static inline bool ggml_rwkv_is_scalar(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_rwkv_is_vector(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_rwkv_is_matrix(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 
4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_rwkv_can_mul_mat(const struct ggml_rwkv_tensor * t0, const struct ggml_rwkv_tensor * t1) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return - (t0->ne[0] == t1->ne[0]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); -} - -static inline bool ggml_rwkv_is_transposed(const struct ggml_rwkv_tensor * tensor) { - return tensor->nb[0] > tensor->nb[1]; -} - -static inline bool ggml_rwkv_is_contiguous(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == GGML_RWKV_TYPE_SIZE[tensor->type] && - tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_RWKV_BLCK_SIZE[tensor->type] && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -} - -static inline bool ggml_rwkv_is_padded_1d(const struct ggml_rwkv_tensor * tensor) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == GGML_RWKV_TYPE_SIZE[tensor->type] && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -} - -static inline bool ggml_rwkv_are_same_shape(const struct ggml_rwkv_tensor * t0, const struct ggml_rwkv_tensor * t1) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return - (t0->ne[0] == t1->ne[0] ) && - (t0->ne[1] == t1->ne[1] ) && - (t0->ne[2] == t1->ne[2] ) && - (t0->ne[3] == t1->ne[3] ); -} - -// check if t1 can be represented as a repeatition of t0 -static inline bool ggml_rwkv_can_repeat(const struct ggml_rwkv_tensor * t0, const struct ggml_rwkv_tensor * t1) { - static_assert(GGML_RWKV_MAX_DIMS == 4, "GGML_RWKV_MAX_DIMS is not 4 - update this function"); - - return - (t1->ne[0]%t0->ne[0] == 0) && - (t1->ne[1]%t0->ne[1] == 0) && - (t1->ne[2]%t0->ne[2] == 0) && - (t1->ne[3]%t0->ne[3] == 0); -} - -static inline int ggml_rwkv_up32(int n) { - return (n + 31) & ~31; -} - -static inline int ggml_rwkv_up64(int n) { - return (n + 63) & ~63; -} - -static inline int ggml_rwkv_up(int n, int m) { - // assert m is a power of 2 - GGML_RWKV_ASSERT((m & (m - 1)) == 0); - return (n + m - 1) & ~(m - 1); -} - -// assert that pointer is aligned to GGML_RWKV_MEM_ALIGN -#define ggml_rwkv_assert_aligned(ptr) \ - GGML_RWKV_ASSERT(((uintptr_t) (ptr))%GGML_RWKV_MEM_ALIGN == 0) - -//////////////////////////////////////////////////////////////////////////////// - -struct ggml_rwkv_context * ggml_rwkv_init(struct ggml_rwkv_init_params params) { - // make this function thread safe - ggml_rwkv_critical_section_start(); - - static bool is_first_call = true; - - bool run_test_suite = false; - - if (is_first_call) { - // initialize time system (required on Windows) - ggml_rwkv_time_init(); - - // initialize GELU, SILU and EXP F32 tables - { - const uint64_t t_start = ggml_rwkv_time_us(); UNUSED(t_start); - - ggml_rwkv_fp16_t ii; - for (int i = 0; i < (1 << 16); ++i) { - uint16_t ui = i; - memcpy(&ii, &ui, sizeof(ii)); - const float f = table_f32_f16[i] = GGML_RWKV_COMPUTE_FP16_TO_FP32(ii); - table_gelu_f16[i] = GGML_RWKV_FP32_TO_FP16(ggml_rwkv_gelu_f32(f)); - table_silu_f16[i] = GGML_RWKV_FP32_TO_FP16(ggml_rwkv_silu_f32(f)); - table_exp_f16[i] = GGML_RWKV_FP32_TO_FP16(expf(f)); - } - - const uint64_t t_end = ggml_rwkv_time_us(); 
UNUSED(t_end); - - GGML_RWKV_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } - - // initialize g_state - { - const uint64_t t_start = ggml_rwkv_time_us(); UNUSED(t_start); - - g_state = (struct ggml_rwkv_state) { - /*.contexts =*/ { { 0 } }, - }; - - for (int i = 0; i < GGML_RWKV_MAX_CONTEXTS; ++i) { - g_state.contexts[i].used = false; - } - - const uint64_t t_end = ggml_rwkv_time_us(); UNUSED(t_end); - - GGML_RWKV_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } - - run_test_suite = true; - - is_first_call = false; - } - - // find non-used context in g_state - struct ggml_rwkv_context * ctx = NULL; - - for (int i = 0; i < GGML_RWKV_MAX_CONTEXTS; i++) { - if (!g_state.contexts[i].used) { - g_state.contexts[i].used = true; - ctx = &g_state.contexts[i].context; - - GGML_RWKV_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); - break; - } - } - - if (ctx == NULL) { - GGML_RWKV_PRINT_DEBUG("%s: no unused context found\n", __func__); - - ggml_rwkv_critical_section_end(); - - return NULL; - } - - *ctx = (struct ggml_rwkv_context) { - /*.mem_size =*/ params.mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), - /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, - /*.mem_buffer_mlocked =*/ false, - /*.n_objects =*/ 0, - /*.objects_begin =*/ NULL, - /*.objects_end =*/ NULL, - /*.scratch =*/ { 0, 0, NULL, }, - /*.scratch_save =*/ { 0, 0, NULL, }, - }; - - GGML_RWKV_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure - - ggml_rwkv_assert_aligned(ctx->mem_buffer); - - GGML_RWKV_PRINT_DEBUG("%s: context initialized\n", __func__); - - ggml_rwkv_critical_section_end(); - - return ctx; -} - -void ggml_rwkv_free(struct ggml_rwkv_context * ctx) { - // make this function thread safe - ggml_rwkv_critical_section_start(); - - bool found = false; - - for (int i = 0; i < GGML_RWKV_MAX_CONTEXTS; i++) { - if (&g_state.contexts[i].context == ctx) { - g_state.contexts[i].used = false; - - GGML_RWKV_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n", - __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size); - -#if GGML_RWKV_MLOCK_SUPPORT - if (ctx->mem_buffer_mlocked) { - if (munlock(ctx->mem_buffer, ctx->mem_size)) { - fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno)); - } - } -#endif - - if (ctx->mem_buffer_owned) { - free(ctx->mem_buffer); - } - - found = true; - break; - } - } - - if (!found) { - GGML_RWKV_PRINT_DEBUG("%s: context not found\n", __func__); - } - - ggml_rwkv_critical_section_end(); -} - -size_t ggml_rwkv_used_mem(const struct ggml_rwkv_context * ctx) { - return ctx->objects_end->offs + ctx->objects_end->size; -} - -size_t ggml_rwkv_set_scratch(struct ggml_rwkv_context * ctx, struct ggml_rwkv_scratch scratch) { - const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0; - - ctx->scratch = scratch; - - return result; -} - -bool ggml_rwkv_mlock_supported(void) { - return GGML_RWKV_MLOCK_SUPPORT; -} - -#if GGML_RWKV_MLOCK_SUPPORT -#ifdef __APPLE__ - #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \ - "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l)." -#else - #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)." 
-#endif -bool ggml_rwkv_mlock(struct ggml_rwkv_context * ctx, char ** err_p) { - if (ctx->mem_buffer_mlocked) { - return true; - } - if (mlock(ctx->mem_buffer, ctx->mem_size)) { - int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION, - ctx->mem_size, strerror(errno)); - GGML_RWKV_ASSERT(ret >= 0); - return false; - } - ctx->mem_buffer_mlocked = true; - return true; -} -#else // GGML_RWKV_MLOCK_SUPPORT -bool ggml_rwkv_mlock(struct ggml_rwkv_context * ctx, char ** err_p) { - *err_p = strdup("can't mlock because it's not supported on this system"); - return false; -} -#endif // GGML_RWKV_MLOCK_SUPPORT - -//////////////////////////////////////////////////////////////////////////////// - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_impl( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int n_dims, - const int* ne, - void* data) { - // always insert objects at the end of the context's memory pool - struct ggml_rwkv_object * obj_cur = ctx->objects_end; - - const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; - const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; - const size_t cur_end = cur_offs + cur_size; - - size_t size_needed = 0; - - if (data == NULL) { - size_needed += GGML_RWKV_TYPE_SIZE[type]*(ne[0]/GGML_RWKV_BLCK_SIZE[type]); - for (int i = 1; i < n_dims; i++) { - size_needed *= ne[i]; - } - // align to GGML_RWKV_MEM_ALIGN - size_needed = ((size_needed + GGML_RWKV_MEM_ALIGN - 1)/GGML_RWKV_MEM_ALIGN)*GGML_RWKV_MEM_ALIGN; - } - - char * const mem_buffer = ctx->mem_buffer; - struct ggml_rwkv_object * const obj_new = (struct ggml_rwkv_object *)(mem_buffer + cur_end); - - if (ctx->scratch.data == NULL || data != NULL) { - size_needed += sizeof(struct ggml_rwkv_tensor); - - if (cur_end + size_needed + GGML_RWKV_OBJECT_SIZE > ctx->mem_size) { - GGML_RWKV_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + size_needed + GGML_RWKV_OBJECT_SIZE, ctx->mem_size); - assert(false); - return NULL; - } - - *obj_new = (struct ggml_rwkv_object) { - .offs = cur_end + GGML_RWKV_OBJECT_SIZE, - .size = size_needed, - .next = NULL, - }; - } else { - if (ctx->scratch.offs + size_needed > ctx->scratch.size) { - GGML_RWKV_PRINT("%s: not enough space in the scratch memory\n", __func__); - assert(false); - return NULL; - } - - if (cur_end + sizeof(struct ggml_rwkv_tensor) + GGML_RWKV_OBJECT_SIZE > ctx->mem_size) { - GGML_RWKV_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + sizeof(struct ggml_rwkv_tensor) + GGML_RWKV_OBJECT_SIZE, ctx->mem_size); - assert(false); - return NULL; - } - - data = (char * const) ctx->scratch.data + ctx->scratch.offs; - - *obj_new = (struct ggml_rwkv_object) { - .offs = cur_end + GGML_RWKV_OBJECT_SIZE, - .size = sizeof(struct ggml_rwkv_tensor), - .next = NULL, - }; - - //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed); - - ctx->scratch.offs += size_needed; - } - - if (obj_cur != NULL) { - obj_cur->next = obj_new; - } else { - // this is the first object in this context - ctx->objects_begin = obj_new; - } - - ctx->objects_end = obj_new; - - //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); - - struct ggml_rwkv_tensor * const result = (struct ggml_rwkv_tensor *)(mem_buffer + obj_new->offs); - - ggml_rwkv_assert_aligned(result); - - *result = (struct ggml_rwkv_tensor) { - /*.type =*/ type, - /*.n_dims =*/ n_dims, - /*.ne =*/ { 1, 1, 1, 
1 }, - /*.nb =*/ { 0, 0, 0, 0 }, - /*.op =*/ GGML_RWKV_OP_NONE, - /*.is_param =*/ false, - /*.grad =*/ NULL, - /*.src0 =*/ NULL, - /*.src1 =*/ NULL, - /*.opt =*/ { NULL }, - /*.n_tasks =*/ 0, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - /*.data =*/ data == NULL ? (void *)(result + 1) : data, - /*.pad =*/ { 0 }, - }; - - ggml_rwkv_assert_aligned(result->data); - - for (int i = 0; i < n_dims; i++) { - result->ne[i] = ne[i]; - } - - result->nb[0] = GGML_RWKV_TYPE_SIZE[type]; - result->nb[1] = result->nb[0]*(result->ne[0]/GGML_RWKV_BLCK_SIZE[type]); - for (int i = 2; i < GGML_RWKV_MAX_DIMS; i++) { - result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; - } - - ctx->n_objects++; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int n_dims, - const int * ne) { - return ggml_rwkv_new_tensor_impl(ctx, type, n_dims, ne, NULL); -} - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_1d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0) { - return ggml_rwkv_new_tensor(ctx, type, 1, &ne0); -} - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_2d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0, - int ne1) { - const int ne[2] = { ne0, ne1 }; - return ggml_rwkv_new_tensor(ctx, type, 2, ne); -} - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_3d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0, - int ne1, - int ne2) { - const int ne[3] = { ne0, ne1, ne2 }; - return ggml_rwkv_new_tensor(ctx, type, 3, ne); -} - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_4d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0, - int ne1, - int ne2, - int ne3) { - const int ne[4] = { ne0, ne1, ne2, ne3 }; - return ggml_rwkv_new_tensor(ctx, type, 4, ne); -} - -struct ggml_rwkv_tensor * ggml_rwkv_new_i32(struct ggml_rwkv_context * ctx, int32_t value) { - ctx->scratch_save = ctx->scratch; - ctx->scratch.data = NULL; - - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_I32, 1); - - ctx->scratch = ctx->scratch_save; - - ggml_rwkv_set_i32(result, value); - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_new_f32(struct ggml_rwkv_context * ctx, float value) { - ctx->scratch_save = ctx->scratch; - ctx->scratch.data = NULL; - - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, 1); - - ctx->scratch = ctx->scratch_save; - - ggml_rwkv_set_f32(result, value); - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_dup_tensor(struct ggml_rwkv_context * ctx, const struct ggml_rwkv_tensor * src) { - return ggml_rwkv_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); -} - -struct ggml_rwkv_tensor * ggml_rwkv_set_zero(struct ggml_rwkv_tensor * tensor) { - memset(tensor->data, 0, ggml_rwkv_nbytes(tensor)); - return tensor; -} - -struct ggml_rwkv_tensor * ggml_rwkv_set_i32 (struct ggml_rwkv_tensor * tensor, int32_t value) { - const int n = ggml_rwkv_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_RWKV_TYPE_Q4_0: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_i8(nc, (int8_t *)(data + 
i*n1), value); - } - } break; - case GGML_RWKV_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_rwkv_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_f16(nc, (ggml_rwkv_fp16_t *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - - return tensor; -} - -struct ggml_rwkv_tensor * ggml_rwkv_set_f32(struct ggml_rwkv_tensor * tensor, float value) { - const int n = ggml_rwkv_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_RWKV_TYPE_Q4_0: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_rwkv_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_f16(nc, (ggml_rwkv_fp16_t *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - - return tensor; -} - -int32_t ggml_rwkv_get_i32_1d(const struct ggml_rwkv_tensor * tensor, int i) { - switch (tensor->type) { - case GGML_RWKV_TYPE_Q4_0: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_I8: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_I16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_I32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_F16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(ggml_rwkv_fp16_t)); - return GGML_RWKV_FP16_TO_FP32(((ggml_rwkv_fp16_t *)(tensor->data))[i]); - } break; - case GGML_RWKV_TYPE_F32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - - return 0.0f; -} - -void ggml_rwkv_set_i32_1d(const 
struct ggml_rwkv_tensor * tensor, int i, int32_t value) { - switch (tensor->type) { - case GGML_RWKV_TYPE_Q4_0: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_I8: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_I16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_I32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_F16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(ggml_rwkv_fp16_t)); - ((ggml_rwkv_fp16_t *)(tensor->data))[i] = GGML_RWKV_FP32_TO_FP16(value); - } break; - case GGML_RWKV_TYPE_F32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -float ggml_rwkv_get_f32_1d(const struct ggml_rwkv_tensor * tensor, int i) { - switch (tensor->type) { - case GGML_RWKV_TYPE_Q4_0: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_I8: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_I16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_I32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_F16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(ggml_rwkv_fp16_t)); - return GGML_RWKV_FP16_TO_FP32(((ggml_rwkv_fp16_t *)(tensor->data))[i]); - } break; - case GGML_RWKV_TYPE_F32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float *)(tensor->data))[i]; - } break; - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - - return 0.0f; -} - -void ggml_rwkv_set_f32_1d(const struct ggml_rwkv_tensor * tensor, int i, float value) { - switch (tensor->type) { - case GGML_RWKV_TYPE_Q4_0: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - GGML_RWKV_ASSERT(false); - } break; - case GGML_RWKV_TYPE_I8: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_I16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_I32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_F16: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(ggml_rwkv_fp16_t)); - ((ggml_rwkv_fp16_t *)(tensor->data))[i] = GGML_RWKV_FP32_TO_FP16(value); - } break; - case GGML_RWKV_TYPE_F32: - { - GGML_RWKV_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -void * ggml_rwkv_get_data(const struct ggml_rwkv_tensor * tensor) { - return tensor->data; -} - -float * ggml_rwkv_get_data_f32(const struct ggml_rwkv_tensor * tensor) { - 
assert(tensor->type == GGML_RWKV_TYPE_F32); - return (float *)(tensor->data); -} - -struct ggml_rwkv_tensor * ggml_rwkv_view_tensor( - struct ggml_rwkv_context * ctx, - const struct ggml_rwkv_tensor * src) { - return ggml_rwkv_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data); -} - -//////////////////////////////////////////////////////////////////////////////// - -// ggml_rwkv_dup - -struct ggml_rwkv_tensor * ggml_rwkv_dup_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_DUP; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_dup( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_dup_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_dup_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_dup_impl(ctx, a, true); -} - -// ggml_rwkv_add - -struct ggml_rwkv_tensor * ggml_rwkv_add_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b, - bool inplace) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_ADD; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_add( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_add_impl(ctx, a, b, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_add_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_add_impl(ctx, a, b, true); -} - -// ggml_rwkv_sub - -struct ggml_rwkv_tensor * ggml_rwkv_sub_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b, - bool inplace) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SUB; - result->grad = is_node ? 
ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_sub( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_sub_impl(ctx, a, b, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_sub_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_sub_impl(ctx, a, b, true); -} - -// ggml_rwkv_mul - -struct ggml_rwkv_tensor * ggml_rwkv_mul_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b, - bool inplace) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - GGML_RWKV_ASSERT(is_node == false); - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_MUL; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_mul( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_mul_impl(ctx, a, b, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_mul_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_mul_impl(ctx, a, b, true); -} - -// ggml_rwkv_div - -struct ggml_rwkv_tensor * ggml_rwkv_div_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b, - bool inplace) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - GGML_RWKV_ASSERT(is_node == false); - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_DIV; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_div( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_div_impl(ctx, a, b, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_div_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_div_impl(ctx, a, b, true); -} - -// ggml_rwkv_sqr - -struct ggml_rwkv_tensor * ggml_rwkv_sqr_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SQR; - result->grad = is_node ? 
ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_sqr( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_sqr_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_sqr_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_sqr_impl(ctx, a, true); -} - -// ggml_rwkv_sqrt - -struct ggml_rwkv_tensor * ggml_rwkv_sqrt_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SQRT; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_sqrt( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_sqrt_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_sqrt_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_sqrt_impl(ctx, a, true); -} - -// ggml_rwkv_sum - -struct ggml_rwkv_tensor * ggml_rwkv_sum( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_1d(ctx, a->type, 1); - - result->op = GGML_RWKV_OP_SUM; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_mean - -struct ggml_rwkv_tensor * ggml_rwkv_mean( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement - is_node = true; - } - - int ne[GGML_RWKV_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor(ctx, GGML_RWKV_TYPE_F32, a->n_dims, ne); - - result->op = GGML_RWKV_OP_MEAN; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_repeat - -struct ggml_rwkv_tensor * ggml_rwkv_repeat( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - GGML_RWKV_ASSERT(ggml_rwkv_can_repeat(a, b)); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (ggml_rwkv_are_same_shape(a, b) && !is_node) { - return a; - } - - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor(ctx, a->type, b->n_dims, b->ne); - - result->op = GGML_RWKV_OP_REPEAT; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_abs - -struct ggml_rwkv_tensor * ggml_rwkv_abs_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_ABS; - result->grad = is_node ? 
ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_abs( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_abs_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_abs_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_abs_impl(ctx, a, true); -} - - -// ggml_rwkv_sgn - -struct ggml_rwkv_tensor * ggml_rwkv_sgn_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SGN; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_sgn( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_sgn_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_sgn_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_sgn_impl(ctx, a, true); -} - -// ggml_rwkv_neg - -struct ggml_rwkv_tensor * ggml_rwkv_neg_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_NEG; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_neg( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_neg_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_neg_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_neg_impl(ctx, a, true); -} - -// ggml_rwkv_exp - -struct ggml_rwkv_tensor * ggml_rwkv_exp( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - struct ggml_rwkv_tensor * result = ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_EXP; - result->grad = a->grad ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_1_minus_x - -struct ggml_rwkv_tensor * ggml_rwkv_1_minus_x( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - struct ggml_rwkv_tensor * result = ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_1_MINUS_X; - result->grad = a->grad ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_max - -struct ggml_rwkv_tensor * ggml_rwkv_max( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(a, b)); - - struct ggml_rwkv_tensor * result = ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_MAX; - result->grad = (a->grad || b->grad) ? 
ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_step - -struct ggml_rwkv_tensor * ggml_rwkv_step_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_STEP; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_step( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_step_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_step_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_step_impl(ctx, a, true); -} - -// ggml_rwkv_relu - -struct ggml_rwkv_tensor * ggml_rwkv_relu_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_RELU; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_relu( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_relu_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_relu_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_relu_impl(ctx, a, true); -} - -// ggml_rwkv_gelu - -struct ggml_rwkv_tensor * ggml_rwkv_gelu_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_GELU; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_gelu( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_gelu_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_gelu_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_gelu_impl(ctx, a, true); -} - -// ggml_rwkv_sigmoid - -struct ggml_rwkv_tensor * ggml_rwkv_sigmoid( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - struct ggml_rwkv_tensor * result = ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SIGMOID; - result->grad = a->grad ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_silu - -struct ggml_rwkv_tensor * ggml_rwkv_silu_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SILU; - result->grad = is_node ? 
ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_silu( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_silu_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_silu_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_silu_impl(ctx, a, true); -} - -// ggml_rwkv_norm - -struct ggml_rwkv_tensor * ggml_rwkv_norm_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_NORM; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store epsilon here? - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_norm( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_norm_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_norm_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_norm_impl(ctx, a, true); -} - -struct ggml_rwkv_tensor * ggml_rwkv_rms_norm_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - - result->op = GGML_RWKV_OP_RMS_NORM; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store epsilon here? - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_rms_norm( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_rms_norm_impl(ctx, a, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_rms_norm_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - return ggml_rwkv_rms_norm_impl(ctx, a, true); -} - -// ggml_rwkv_mul_mat - -struct ggml_rwkv_tensor * ggml_rwkv_mul_mat( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - GGML_RWKV_ASSERT(ggml_rwkv_can_mul_mat(a, b)); - GGML_RWKV_ASSERT(!ggml_rwkv_is_transposed(a)); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor(ctx, GGML_RWKV_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); - - result->op = GGML_RWKV_OP_MUL_MAT; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_scale - -struct ggml_rwkv_tensor * ggml_rwkv_scale_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b, - bool inplace) { - GGML_RWKV_ASSERT(ggml_rwkv_is_scalar(b)); - GGML_RWKV_ASSERT(ggml_rwkv_is_padded_1d(a)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_rwkv_tensor * result = inplace ? 
ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - struct ggml_rwkv_tensor * result = ggml_rwkv_view_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SCALE; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_scale( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_scale_impl(ctx, a, b, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_scale_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_scale_impl(ctx, a, b, true); -} - -// ggml_rwkv_cpy - -struct ggml_rwkv_tensor * ggml_rwkv_cpy_impl( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b, - bool inplace) { - GGML_RWKV_ASSERT(ggml_rwkv_nelements(a) == ggml_rwkv_nelements(b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // make a view of the destination - struct ggml_rwkv_tensor * result = ggml_rwkv_view_tensor(ctx, b); - - result->op = GGML_RWKV_OP_CPY; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_cpy( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_cpy_impl(ctx, a, b, false); -} - -struct ggml_rwkv_tensor * ggml_rwkv_cpy_inplace( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - return ggml_rwkv_cpy_impl(ctx, a, b, true); -} - -// ggml_rwkv_reshape - -struct ggml_rwkv_tensor * ggml_rwkv_reshape( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(a)); - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(b)); - GGML_RWKV_ASSERT(ggml_rwkv_nelements(a) == ggml_rwkv_nelements(b)); - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data); - - result->op = GGML_RWKV_OP_RESHAPE; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_reshape_2d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - int ne1) { - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(a)); - GGML_RWKV_ASSERT(ggml_rwkv_nelements(a) == ne0*ne1); - - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int ne[2] = { ne0, ne1 }; - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_impl(ctx, a->type, 2, ne, a->data); - - result->op = GGML_RWKV_OP_RESHAPE; - result->grad = is_node ? 
ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_rwkv_tensor * ggml_rwkv_reshape_3d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - int ne1, - int ne2) { - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(a)); - GGML_RWKV_ASSERT(ggml_rwkv_nelements(a) == ne0*ne1*ne2); - - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int ne[3] = { ne0, ne1, ne2 }; - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_impl(ctx, a->type, 3, ne, a->data); - - result->op = GGML_RWKV_OP_RESHAPE; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_view_1d - -struct ggml_rwkv_tensor * ggml_rwkv_view_1d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - size_t offset) { - if (a->grad) { - GGML_RWKV_ASSERT(false); // gradient propagation is not supported - } - - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); - - result->op = GGML_RWKV_OP_VIEW; - result->grad = NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store the offset here? - - return result; -} - -// ggml_rwkv_view_2d - -struct ggml_rwkv_tensor * ggml_rwkv_view_2d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - int ne1, - size_t nb1, - size_t offset) { - if (a->grad) { - GGML_RWKV_ASSERT(false); // gradient propagation is not supported - } - - const int ne[GGML_RWKV_MAX_DIMS] = { ne0, ne1, 1, 1 }; - - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); - - result->nb[1] = nb1; - result->nb[2] = result->nb[1]*ne1; - result->nb[3] = result->nb[2]; - - result->op = GGML_RWKV_OP_VIEW; - result->grad = NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store the offset here? - - return result; -} - -// ggml_rwkv_permute - -struct ggml_rwkv_tensor * ggml_rwkv_permute( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int axis0, - int axis1, - int axis2, - int axis3) { - GGML_RWKV_ASSERT(axis0 >= 0 && axis0 < GGML_RWKV_MAX_DIMS); - GGML_RWKV_ASSERT(axis1 >= 0 && axis1 < GGML_RWKV_MAX_DIMS); - GGML_RWKV_ASSERT(axis2 >= 0 && axis2 < GGML_RWKV_MAX_DIMS); - GGML_RWKV_ASSERT(axis3 >= 0 && axis3 < GGML_RWKV_MAX_DIMS); - - GGML_RWKV_ASSERT(axis0 != axis1); - GGML_RWKV_ASSERT(axis0 != axis2); - GGML_RWKV_ASSERT(axis0 != axis3); - GGML_RWKV_ASSERT(axis1 != axis2); - GGML_RWKV_ASSERT(axis1 != axis3); - GGML_RWKV_ASSERT(axis2 != axis3); - - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_rwkv_tensor * result = ggml_rwkv_view_tensor(ctx, a); - - int ne[GGML_RWKV_MAX_DIMS]; - int nb[GGML_RWKV_MAX_DIMS]; - - ne[axis0] = a->ne[0]; - ne[axis1] = a->ne[1]; - ne[axis2] = a->ne[2]; - ne[axis3] = a->ne[3]; - - nb[axis0] = a->nb[0]; - nb[axis1] = a->nb[1]; - nb[axis2] = a->nb[2]; - nb[axis3] = a->nb[3]; - - result->ne[0] = ne[0]; - result->ne[1] = ne[1]; - result->ne[2] = ne[2]; - result->ne[3] = ne[3]; - - result->nb[0] = nb[0]; - result->nb[1] = nb[1]; - result->nb[2] = nb[2]; - result->nb[3] = nb[3]; - - result->op = GGML_RWKV_OP_PERMUTE; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store the permutation here? 
- - return result; -} - -// ggml_rwkv_transpose - -struct ggml_rwkv_tensor * ggml_rwkv_transpose( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_rwkv_tensor * result = ggml_rwkv_view_tensor(ctx, a); - - result->ne[0] = a->ne[1]; - result->ne[1] = a->ne[0]; - - result->nb[0] = a->nb[1]; - result->nb[1] = a->nb[0]; - - result->op = GGML_RWKV_OP_TRANSPOSE; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_get_rows - -struct ggml_rwkv_tensor * ggml_rwkv_get_rows( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - GGML_RWKV_ASSERT(ggml_rwkv_is_matrix(a) && ggml_rwkv_is_vector(b) && b->type == GGML_RWKV_TYPE_I32); - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: implement non F32 return - //struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor_2d(ctx, GGML_RWKV_TYPE_F32, a->ne[0], b->ne[0]); - - result->op = GGML_RWKV_OP_GET_ROWS; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_diag_mask_inf - -struct ggml_rwkv_tensor * ggml_rwkv_diag_mask_inf( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int n_past) { - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - struct ggml_rwkv_tensor * result = ggml_rwkv_view_tensor(ctx, a); - struct ggml_rwkv_tensor * b = ggml_rwkv_new_i32(ctx, n_past); - - result->op = GGML_RWKV_OP_DIAG_MASK_INF; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_soft_max - -struct ggml_rwkv_tensor * ggml_rwkv_soft_max( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a) { - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_rwkv_tensor * result = inplace ? ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - struct ggml_rwkv_tensor * result = ggml_rwkv_view_tensor(ctx, a); - - result->op = GGML_RWKV_OP_SOFT_MAX; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rwkv_rope - -struct ggml_rwkv_tensor * ggml_rwkv_rope( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int n_past, - int n_dims, - int mode) { - GGML_RWKV_ASSERT(n_past >= 0); - bool is_node = false; - - if (a->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_rwkv_tensor * result = inplace ? 
ggml_rwkv_view_tensor(ctx, a) : ggml_rwkv_dup_tensor(ctx, a); - struct ggml_rwkv_tensor * result = ggml_rwkv_view_tensor(ctx, a); - - struct ggml_rwkv_tensor * b = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_I32, 3); - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_dims; - ((int32_t *) b->data)[2] = mode; - - result->op = GGML_RWKV_OP_ROPE; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_conv_1d_1s - -struct ggml_rwkv_tensor * ggml_rwkv_conv_1d_1s( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - GGML_RWKV_ASSERT(ggml_rwkv_is_matrix(b)); - GGML_RWKV_ASSERT(a->ne[1] == b->ne[1]); - GGML_RWKV_ASSERT(a->ne[3] == 1); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int ne[4] = { b->ne[0], a->ne[2], 1, 1, }; - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor(ctx, GGML_RWKV_TYPE_F32, 2, ne); - - result->op = GGML_RWKV_OP_CONV_1D_1S; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_conv_1d_2s - -struct ggml_rwkv_tensor * ggml_rwkv_conv_1d_2s( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b) { - GGML_RWKV_ASSERT(ggml_rwkv_is_matrix(b)); - GGML_RWKV_ASSERT(a->ne[1] == b->ne[1]); - GGML_RWKV_ASSERT(a->ne[3] == 1); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, }; - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor(ctx, GGML_RWKV_TYPE_F32, 2, ne); - - result->op = GGML_RWKV_OP_CONV_1D_2S; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_rwkv_flash_attn - -struct ggml_rwkv_tensor * ggml_rwkv_flash_attn( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * q, - struct ggml_rwkv_tensor * k, - struct ggml_rwkv_tensor * v, - bool masked) { - GGML_RWKV_ASSERT(ggml_rwkv_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - //struct ggml_rwkv_tensor * result = ggml_rwkv_dup_tensor(ctx, q); - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor(ctx, GGML_RWKV_TYPE_F32, 4, q->ne); - - result->op = GGML_RWKV_OP_FLASH_ATTN; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = q; - result->src1 = k; - result->opt[0] = v; - result->opt[1] = ggml_rwkv_new_i32(ctx, masked ? 
1 : 0); - - return result; -} - -// ggml_rwkv_flash_ff - -struct ggml_rwkv_tensor * ggml_rwkv_flash_ff( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b0, - struct ggml_rwkv_tensor * b1, - struct ggml_rwkv_tensor * c0, - struct ggml_rwkv_tensor * c1) { - GGML_RWKV_ASSERT(ggml_rwkv_can_mul_mat(b0, a)); - // TODO: more checks - - bool is_node = false; - - if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { - GGML_RWKV_ASSERT(false); // TODO: implement backward - is_node = true; - } - - //struct ggml_rwkv_tensor * result = ggml_rwkv_dup_tensor(ctx, a); - struct ggml_rwkv_tensor * result = ggml_rwkv_new_tensor(ctx, GGML_RWKV_TYPE_F32, 4, a->ne); - - result->op = GGML_RWKV_OP_FLASH_FF; - result->grad = is_node ? ggml_rwkv_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b0; - result->opt[0] = b1; - result->opt[1] = c0; - result->opt[2] = c1; - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -void ggml_rwkv_set_param( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * tensor) { - tensor->is_param = true; - - GGML_RWKV_ASSERT(tensor->grad == NULL); - tensor->grad = ggml_rwkv_dup_tensor(ctx, tensor); -} - -// ggml_rwkv_compute_forward_dup - -static void ggml_rwkv_compute_forward_dup_f16( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(params->ith == 0); - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(dst)); - GGML_RWKV_ASSERT(ggml_rwkv_nelements(dst) == ggml_rwkv_nelements(src0)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - if (ggml_rwkv_is_contiguous(src0) && src0->type == dst->type) { - memcpy(dst->data, src0->data, ggml_rwkv_nelements(dst) * GGML_RWKV_TYPE_SIZE[src0->type]); - return; - } - - if (src0->nb[0] == sizeof(ggml_rwkv_fp16_t)) { - if (dst->type == GGML_RWKV_TYPE_F16) { - size_t id = 0; - const size_t rs = ne00*nb00; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - char * dst_ptr = (char *) dst->data + id*rs; - - memcpy(dst_ptr, src0_ptr, rs); - - id++; - } - } - } - } else if (dst->type == GGML_RWKV_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_rwkv_fp16_t * src0_ptr = (ggml_rwkv_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_RWKV_FP16_TO_FP32(*src0_ptr); - id++; - } - } - } - } - } else { - GGML_RWKV_ASSERT(false); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_RWKV_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_rwkv_fp16_t * src0_ptr = 
(ggml_rwkv_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_RWKV_FP16_TO_FP32(*src0_ptr); - id++; - } - } - } - } - } else if (dst->type == GGML_RWKV_TYPE_F16) { - size_t id = 0; - ggml_rwkv_fp16_t * dst_ptr = (ggml_rwkv_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_rwkv_fp16_t * src0_ptr = (ggml_rwkv_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - } - } - } else { - GGML_RWKV_ASSERT(false); // TODO: implement - } - } -} - -static void ggml_rwkv_compute_forward_dup_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(params->ith == 0); - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(dst)); - GGML_RWKV_ASSERT(ggml_rwkv_nelements(dst) == ggml_rwkv_nelements(src0)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - if (ggml_rwkv_is_contiguous(src0) && src0->type == dst->type) { - memcpy(dst->data, src0->data, ggml_rwkv_nelements(dst) * GGML_RWKV_TYPE_SIZE[src0->type]); - return; - } - - if (src0->nb[0] == sizeof(float)) { - if (dst->type == GGML_RWKV_TYPE_F32) { - size_t id = 0; - const size_t rs = ne00*nb00; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - char * dst_ptr = (char *) dst->data + id*rs; - - memcpy(dst_ptr, src0_ptr, rs); - - id++; - } - } - } - } else if (dst->type == GGML_RWKV_TYPE_F16) { - size_t id = 0; - ggml_rwkv_fp16_t * dst_ptr = (ggml_rwkv_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_RWKV_FP32_TO_FP16(*src0_ptr); - id++; - } - } - } - } - } else { - GGML_RWKV_ASSERT(false); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_RWKV_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - } - } - } else if (dst->type == GGML_RWKV_TYPE_F16) { - size_t id = 0; - ggml_rwkv_fp16_t * dst_ptr = (ggml_rwkv_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_RWKV_FP32_TO_FP16(*src0_ptr); - id++; - } - } - } - } - } else { - 
GGML_RWKV_ASSERT(false); // TODO: implement - } - } -} - -static void ggml_rwkv_compute_forward_dup( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_dup_f16(params, src0, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_dup_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_add - -static void ggml_rwkv_compute_forward_add_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(src0, src1) && ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - - GGML_RWKV_ASSERT( nb0 == sizeof(float)); - GGML_RWKV_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - const int j0 = (n/nth)*ith; - const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); - - for (int j = j0; j < j1; j++) { - ggml_rwkv_vec_add_f32(nc, - (float *) ((char *) dst->data + j*nb1), - (float *) ((char *) src0->data + j*nb01), - (float *) ((char *) src1->data + j*nb11)); - } - } else { - // src1 is not contiguous - for (int j = ith; j < n; j += nth) { - float * dst_ptr = (float *) ((char *) dst->data + j*nb1); - float * src0_ptr = (float *) ((char *) src0->data + j*nb01); - for (int i = 0; i < nc; i++) { - float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); - - dst_ptr[i] = src0_ptr[i] + *src1_ptr; - } - } - } -} - -static void ggml_rwkv_compute_forward_add( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_add_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_sub - -static void ggml_rwkv_compute_forward_sub_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, src1) && ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i 
< n; i++) { - ggml_rwkv_vec_sub_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_sub( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_sub_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_mul - -static void ggml_rwkv_compute_forward_mul_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, src1) && ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_mul_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_mul( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_mul_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_div - -static void ggml_rwkv_compute_forward_div_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, src1) && ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_div_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_div( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_div_f32(params, src0, src1, dst); - } break; - case 
GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_sqr - -static void ggml_rwkv_compute_forward_sqr_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_sqr_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_sqr( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_sqr_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_sqrt - -static void ggml_rwkv_compute_forward_sqrt_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_sqrt_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_sqrt( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_sqrt_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_sum - -static void ggml_rwkv_compute_forward_sum_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_is_scalar(dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - assert(ggml_rwkv_is_scalar(dst)); - assert(src0->nb[0] == sizeof(float)); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - for (int i03 = 0; i03 < 
ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - ggml_rwkv_vec_sum_f32(ne00, - (float *) (dst->data), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - } - } - } -} - -static void ggml_rwkv_compute_forward_sum( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_sum_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_mean - -static void ggml_rwkv_compute_forward_mean_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - - assert(ne0 == 1); - assert(ne1 == ne01); - assert(ne2 == ne02); - assert(ne3 == ne03); - - UNUSED(ne0); - UNUSED(ne1); - UNUSED(ne2); - UNUSED(ne3); - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - ggml_rwkv_vec_sum_f32(ne00, - (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - - *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; - } - } - } -} - -static void ggml_rwkv_compute_forward_mean( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_mean_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_repeat - -static void ggml_rwkv_compute_forward_repeat_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_can_repeat(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // TODO: implement support for rank > 2 tensors - assert(src0->ne[2] == 1); - assert(src0->ne[3] == 1); - assert( dst->ne[2] == 1); - assert( dst->ne[3] == 1); - - const int nc = dst->ne[0]; - const int nr = dst->ne[1]; - const int nc0 = src0->ne[0]; - const int nr0 = src0->ne[1]; - const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_rwkv_can_repeat - const int nrr = nr/nr0; 
// guaranteed to be an integer due to the check in ggml_rwkv_can_repeat - - // TODO: support for transposed / permuted tensors - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - // TODO: maybe this is not optimal? - for (int i = 0; i < nrr; i++) { - for (int j = 0; j < ncr; j++) { - for (int k = 0; k < nr0; k++) { - ggml_rwkv_vec_cpy_f32(nc0, - (float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])), - (float *) ((char *) src0->data + ( k)*(src0->nb[1]))); - } - } - } -} - -static void ggml_rwkv_compute_forward_repeat( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_repeat_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_abs - -static void ggml_rwkv_compute_forward_abs_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_abs_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_abs( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_abs_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_sgn - -static void ggml_rwkv_compute_forward_sgn_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_sgn_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_sgn( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_sgn_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case 
GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_neg - -static void ggml_rwkv_compute_forward_neg_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_neg_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_neg( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_neg_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_exp - -static void ggml_rwkv_compute_forward_exp_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_exp_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_exp( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_exp_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_1_minus_x - -static void ggml_rwkv_compute_forward_1_minus_x_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_1_minus_x_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_1_minus_x( - const struct ggml_rwkv_compute_params * params, - const struct 
ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_1_minus_x_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_max - -static void ggml_rwkv_compute_forward_max_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, src1) && ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_element_wise_max_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_max( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_max_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_step - -static void ggml_rwkv_compute_forward_step_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_step_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_step( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_step_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_relu - -static void ggml_rwkv_compute_forward_relu_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if 
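Note: exp, element-wise max, sigmoid and 1_minus_x are not general-purpose ggml ops; they were added because RWKV builds its time mixing out of them at graph level. 1_minus_x serves the token-shift interpolation x*mix + x_prev*(1 - mix), while exp and max combine into the numerically stable WKV recurrence. The following per-channel scalar sketch is only an illustration of how those primitives are typically combined (names are not taken from this file; time_decay is assumed to already be in its negative log-space form, as rwkv.cpp maps it at load time):

    #include <math.h>

    // One channel of a numerically stable WKV update. aa, bb, pp are the
    // running state; k, v are the current token's key/value projections.
    static float toy_wkv_step(float time_first, float time_decay,
                              float k, float v,
                              float * aa, float * bb, float * pp) {
        // output: exp-weighted average of past values plus the current token,
        // with max() keeping the exponents small
        float ww  = time_first + k;
        float qq  = fmaxf(*pp, ww);
        float e1  = expf(*pp - qq);
        float e2  = expf(ww - qq);
        float wkv = (e1 * *aa + e2 * v) / (e1 * *bb + e2);

        // state update, decay folded in through the same max/exp trick
        ww  = *pp + time_decay;
        qq  = fmaxf(ww, k);
        e1  = expf(ww - qq);
        e2  = expf(k - qq);
        *aa = e1 * *aa + e2 * v;
        *bb = e1 * *bb + e2;
        *pp = qq;
        return wkv;
    }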
(params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_relu( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_relu_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_gelu - -static void ggml_rwkv_compute_forward_gelu_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(src0)); - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(dst)); - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_rwkv_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_rwkv_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_rwkv_compute_forward_gelu( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_gelu_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - - //printf("XXXXXXXX gelu\n"); -} - -// ggml_rwkv_compute_forward_sigmoid - -static void ggml_rwkv_compute_forward_sigmoid_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_rwkv_vec_sigmoid_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_rwkv_compute_forward_sigmoid( - const struct ggml_rwkv_compute_params * 
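Note: gelu (and silu below) are the first kernels here that actually run multi-threaded: rows are split into contiguous chunks of dr = ceil(nr/nth) rows and thread ith processes the half-open range [ir0, ir1). A tiny sketch of that partition, assuming the same ceil-divide convention as the code above:

    // Contiguous row partition used by the threaded kernels:
    // thread ith of nth handles rows [ir0, ir1).
    static void toy_row_range(int nr, int nth, int ith, int * ir0, int * ir1) {
        const int dr = (nr + nth - 1) / nth;          // ceil(nr / nth) rows per thread
        *ir0 = dr * ith;
        *ir1 = (*ir0 + dr < nr) ? *ir0 + dr : nr;     // MIN(ir0 + dr, nr)
    }

    // Example: nr = 10 rows over nth = 4 threads gives
    // [0,3) [3,6) [6,9) [9,10); the last thread absorbs the remainder.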
params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_sigmoid_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_silu - -static void ggml_rwkv_compute_forward_silu_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(src0)); - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(dst)); - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_rwkv_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_rwkv_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_rwkv_compute_forward_silu( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_silu_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - - -// ggml_rwkv_compute_forward_norm - -static void ggml_rwkv_compute_forward_norm_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - GGML_RWKV_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const float eps = 1e-5f; // TODO: make this a parameter - - // TODO: optimize - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - ggml_rwkv_float sum = 0.0; - for (int i00 = 0; i00 < ne00; i00++) { - sum += (ggml_rwkv_float)x[i00]; - } - - float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - ggml_rwkv_float 
sum2 = 0.0; - for (int i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_rwkv_float)(v*v); - } - - float variance = sum2/ne00; - const float scale = 1.0f/sqrtf(variance + eps); - - ggml_rwkv_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void ggml_rwkv_compute_forward_norm( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_norm_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -static void ggml_rwkv_compute_forward_rms_norm_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - GGML_RWKV_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const float eps = 1e-6f; // TODO: make this a parameter - - // TODO: optimize - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - ggml_rwkv_float sum = 0.0; - for (int i00 = 0; i00 < ne00; i00++) { - sum += (ggml_rwkv_float)(x[i00] * x[i00]); - } - - float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - - const float scale = 1.0f/sqrtf(mean + eps); - - ggml_rwkv_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void ggml_rwkv_compute_forward_rms_norm( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_rms_norm_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - - -// ggml_rwkv_compute_forward_mul_mat - -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) -// helper function to determine if it is better to use BLAS or not -// for large matrices, BLAS is faster -static bool ggml_rwkv_compute_forward_mul_mat_use_blas( - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - //const int ne00 = src0->ne[0]; - //const int ne01 = src0->ne[1]; - - const int ne10 = src1->ne[0]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - - // TODO: find the optimal values 
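Note: norm and rms_norm differ only in what they remove. norm subtracts the per-row mean and divides by sqrt(variance + eps) with eps = 1e-5, while rms_norm keeps the values as-is and divides by the root mean square with eps = 1e-6. A scalar reference of both, matching the formulas in the loops above (the learned scale and bias are applied by separate mul/add ops in the graph, not here):

    #include <math.h>

    // LayerNorm core: subtract the mean, divide by sqrt(variance + eps).
    static void toy_layer_norm_f32(int n, float * y, const float * x, float eps) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) sum += x[i];
        const float mean = (float)(sum / n);

        double sum2 = 0.0;
        for (int i = 0; i < n; ++i) {
            const float v = x[i] - mean;
            y[i] = v;
            sum2 += (double)(v * v);
        }
        const float scale = 1.0f / sqrtf((float)(sum2 / n) + eps);
        for (int i = 0; i < n; ++i) y[i] *= scale;
    }

    // RMSNorm core: no mean subtraction, scale by 1/sqrt(mean(x^2) + eps).
    static void toy_rms_norm_f32(int n, float * y, const float * x, float eps) {
        double sum2 = 0.0;
        for (int i = 0; i < n; ++i) sum2 += (double)x[i] * x[i];
        const float scale = 1.0f / sqrtf((float)(sum2 / n) + eps);
        for (int i = 0; i < n; ++i) y[i] = x[i] * scale;
    }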
for these - if (ggml_rwkv_is_contiguous(src0) && - ggml_rwkv_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { - - /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ - return true; - } - - return false; -} -#endif - -static void ggml_rwkv_compute_forward_mul_mat_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - const int ne10 = src1->ne[0]; -#endif - const int ne11 = src1->ne[1]; -#ifndef NDEBUG - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - - const int nb00 = src0->nb[0]; -#endif - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - -#ifndef NDEBUG - const int nb10 = src1->nb[0]; -#endif - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - assert(ne02 == ne12); - assert(ne03 == ne13); - assert(ne2 == ne12); - assert(ne3 == ne13); - - // we don't support permuted src0 or src1 - assert(nb00 == sizeof(float)); - assert(nb10 == sizeof(float)); - - // dst cannot be transposed or permuted - assert(nb0 == sizeof(float)); - assert(nb0 <= nb1); - assert(nb1 <= nb2); - assert(nb2 <= nb3); - - assert(ne0 == ne01); - assert(ne1 == ne11); - assert(ne2 == ne02); - assert(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - if (ggml_rwkv_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - if (params->ith != 0) { - return; - } - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - - // zT = y * xT - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); - } - } - - //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_rwkv_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // parallelize by src0 rows using ggml_rwkv_vec_dot_f32 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - for (int ic = 0; ic < ne11; ++ic) { - // src1 indices - const int 
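Note: the "zT = y * xT" comment is the whole trick of the BLAS path. src0 is stored row-major as ne01 x ne00 and src1 as ne11 x ne10 (with ne10 == ne00), so dst = src1 * src0^T comes out directly as a row-major ne11 x ne01 matrix from a single cblas_sgemm with CblasTrans on the second operand. For dimension bookkeeping only, the naive loop that computes the same thing:

    // What the cblas_sgemm call above produces, written out naively.
    //   y: src1 rows, row-major, ne11 x ne10
    //   x: src0 rows, row-major, ne01 x ne10   (ne10 == ne00)
    //   d: dst,       row-major, ne11 x ne01   (d = y * x^T)
    static void toy_sgemm_y_xT(int ne11, int ne01, int ne10,
                               const float * y, const float * x, float * d) {
        for (int i = 0; i < ne11; ++i) {
            for (int j = 0; j < ne01; ++j) {
                float acc = 0.0f;
                for (int k = 0; k < ne10; ++k) {
                    acc += y[i * ne10 + k] * x[j * ne10 + k];
                }
                d[i * ne01 + j] = acc;
            }
        }
    }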
i13 = i03; - const int i12 = i02; - const int i11 = ic; - - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; - - ggml_rwkv_vec_dot_f32(ne00, - (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), - (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); - } - } - - //int64_t t1 = ggml_rwkv_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} -} - -static void ggml_rwkv_compute_forward_mul_mat_f16_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - GGML_RWKV_ASSERT(ne02 == ne12); - GGML_RWKV_ASSERT(ne03 == ne13); - GGML_RWKV_ASSERT(ne2 == ne12); - GGML_RWKV_ASSERT(ne3 == ne13); - - // TODO: we don't support permuted src0 - GGML_RWKV_ASSERT(nb00 == sizeof(ggml_rwkv_fp16_t)); - - // dst cannot be transposed or permuted - GGML_RWKV_ASSERT(nb0 == sizeof(float)); - GGML_RWKV_ASSERT(nb0 <= nb1); - GGML_RWKV_ASSERT(nb1 <= nb2); - GGML_RWKV_ASSERT(nb2 <= nb3); - - GGML_RWKV_ASSERT(ne0 == ne01); - GGML_RWKV_ASSERT(ne1 == ne11); - GGML_RWKV_ASSERT(ne2 == ne02); - GGML_RWKV_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - if (ggml_rwkv_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - GGML_RWKV_ASSERT(nb10 == sizeof(float)); - - if (params->ith != 0) { - return; - } - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - float * const wdata = params->wdata; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - { - size_t id = 0; - for (int i01 = 0; i01 < ne01; ++i01) { - for (int i00 = 0; i00 < ne00; ++i00) { - wdata[id++] = GGML_RWKV_FP16_TO_FP32(*(ggml_rwkv_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); - } - } - } - - const 
float * x = wdata; - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - - // zT = y * xT - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); - } - } - - /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_rwkv_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ - - return; - } -#endif - - if (params->type == GGML_RWKV_TASK_INIT) { - ggml_rwkv_fp16_t * const wdata = params->wdata; - - size_t id = 0; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - for (int i10 = 0; i10 < ne10; ++i10) { - wdata[id++] = GGML_RWKV_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); - } - } - } - } - - GGML_RWKV_ASSERT(id*sizeof(ggml_rwkv_fp16_t) <= params->wsize); - - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // fp16 -> half the size, so divide by 2 - // TODO: do not support transposed src1 - assert(nb10/2 == sizeof(ggml_rwkv_fp16_t)); - - // parallelize by src0 rows using ggml_rwkv_vec_dot_f16 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_rwkv_fp16_t * wdata = params->wdata; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int i13 = i03; - const int i12 = i02; - - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - ggml_rwkv_fp16_t * src0_row = (ggml_rwkv_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - ggml_rwkv_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; - - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - - for (int ic = 0; ic < ne11; ++ic) { - ggml_rwkv_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); - } - } - - //int64_t t1 = ggml_rwkv_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} -} - -typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k); -typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k); -typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y); - -typedef struct { - dequantize_row_q_t dequantize_row_q; - quantize_row_q_t quantize_row_q; - vec_dot_q_t vec_dot_q; -} quantize_fns_t; - -static const quantize_fns_t quantize_fns[GGML_RWKV_TYPE_COUNT] = { - [GGML_RWKV_TYPE_Q4_0] = { - .dequantize_row_q = dequantize_row_q4_0, - .quantize_row_q = quantize_row_q4_0, - .vec_dot_q = ggml_rwkv_vec_dot_q4_0, - }, - [GGML_RWKV_TYPE_Q4_1] = { - .dequantize_row_q = dequantize_row_q4_1, - .quantize_row_q = quantize_row_q4_1, - .vec_dot_q = ggml_rwkv_vec_dot_q4_1, - }, - 
[GGML_RWKV_TYPE_Q4_1_O] = { - .dequantize_row_q = dequantize_row_q4_1_o, - .quantize_row_q = quantize_row_q4_1_o, - .vec_dot_q = NULL, - }, -}; - -static void ggml_rwkv_compute_forward_mul_mat_q_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - GGML_RWKV_ASSERT(ne02 == ne12); - GGML_RWKV_ASSERT(ne03 == ne13); - GGML_RWKV_ASSERT(ne2 == ne12); - GGML_RWKV_ASSERT(ne3 == ne13); - - const enum ggml_rwkv_type type = src0->type; - quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; - vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; - - // we don't support permuted src0 or src1 - GGML_RWKV_ASSERT(nb00 == (int) GGML_RWKV_TYPE_SIZE[type]); - GGML_RWKV_ASSERT(nb10 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_RWKV_ASSERT(nb0 == sizeof(float)); - GGML_RWKV_ASSERT(nb0 <= nb1); - GGML_RWKV_ASSERT(nb1 <= nb2); - GGML_RWKV_ASSERT(nb2 <= nb3); - - GGML_RWKV_ASSERT(ne0 == ne01); - GGML_RWKV_ASSERT(ne1 == ne11); - GGML_RWKV_ASSERT(ne2 == ne02); - GGML_RWKV_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - if (ggml_rwkv_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - if (params->ith != 0) { - return; - } - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - float * const wdata = params->wdata; - dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - { - size_t id = 0; - for (int i01 = 0; i01 < ne01; ++i01) { - dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); - id += ne00; - } - } - - const float * x = wdata; - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - - // zT = y * xT - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); - } - } - - //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_rwkv_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - - if (params->type == GGML_RWKV_TASK_INIT) { - char * wdata = params->wdata; - const size_t row_size = ne10*GGML_RWKV_TYPE_SIZE[type]/GGML_RWKV_BLCK_SIZE[type]; - - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - 
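Note: quantize_fns keys the generic quantized matmul on the tensor type; each entry supplies a row dequantizer, a row quantizer (used on src1 during TASK_INIT), and a fused dot product. For Q4_1 a block holds 32 weights as 4-bit codes plus a per-block scale d and bias m, so each weight decodes as d*q + m. A scalar reference dequantizer under that layout (struct shown from memory of the old format; field names illustrative, two consecutive weights assumed packed into the low and high nibble of each byte):

    #include <stdint.h>

    #define TOY_QK 32  // weights per quantization block, as in the code above

    typedef struct {
        float   d;               // per-block scale
        float   m;               // per-block bias (minimum)
        uint8_t qs[TOY_QK / 2];  // 32 x 4-bit codes, two per byte
    } toy_block_q4_1;

    // Scalar reference: w = d * q + m for each 4-bit code q in 0..15.
    static void toy_dequantize_q4_1(const toy_block_q4_1 * b, float * y) {
        for (int j = 0; j < TOY_QK / 2; ++j) {
            const uint8_t v = b->qs[j];
            y[2 * j + 0] = b->d * (float)(v & 0x0F) + b->m;
            y[2 * j + 1] = b->d * (float)(v >> 4)   + b->m;
        }
    }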
quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; - } - } - } - - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // parallelize by src0 rows using ggml_rwkv_vec_dot_q - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - void * wdata = params->wdata; - const size_t row_size = ne00*GGML_RWKV_TYPE_SIZE[type]/GGML_RWKV_BLCK_SIZE[type]; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int i13 = i03; - const int i12 = i02; - - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); - - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - - assert(ne00 % 32 == 0); - - for (int ic = 0; ic < ne11; ++ic) { - vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); - } - } - - //int64_t t1 = ggml_rwkv_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} -} - -static void ggml_rwkv_compute_forward_mul_mat_q4_1_o_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - GGML_RWKV_ASSERT(ne02 == ne12); - GGML_RWKV_ASSERT(ne03 == ne13); - GGML_RWKV_ASSERT(ne2 == ne12); - GGML_RWKV_ASSERT(ne3 == ne13); - - const enum ggml_rwkv_type type = src0->type; - - // we don't support permuted src0 or src1 - GGML_RWKV_ASSERT(nb00 == (int) GGML_RWKV_TYPE_SIZE[type]); - GGML_RWKV_ASSERT(nb10 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_RWKV_ASSERT(nb0 == sizeof(float)); - GGML_RWKV_ASSERT(nb0 <= nb1); - GGML_RWKV_ASSERT(nb1 <= nb2); - GGML_RWKV_ASSERT(nb2 <= nb3); - - GGML_RWKV_ASSERT(ne0 == ne01); - 
GGML_RWKV_ASSERT(ne1 == ne11); - GGML_RWKV_ASSERT(ne2 == ne02); - GGML_RWKV_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - if (ggml_rwkv_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - if (params->ith != 0) { - return; - } - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - float * const wdata = params->wdata; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - { - size_t id = 0; - for (int i01 = 0; i01 < ne01; ++i01) { - dequantize_row_q4_1_o((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); - id += ne00; - } - } - - const float * x = wdata; - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - - // zT = y * xT - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); - } - } - - //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_rwkv_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // parallelize by src0 rows using ggml_rwkv_vec_dot_f32 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - -#if defined(__AVX2__) - float outlier_mask[QK]; - memset(outlier_mask, 0, QK * sizeof(float)); -#endif - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - -#if defined(__AVX2__) - for (int ic = 0; ic < ne11; ++ic) { - // src1 indices - const int i13 = i03; - const int i12 = i02; - const int i11 = ic; - - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; - - const int block_count = ne00 / QK; - - const block_q4_1_o * row_blocks = (block_q4_1_o *) ((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03)); - - __m256 accum = _mm256_setzero_ps(); - - // Here we do fused dequantization and dot product. - for (int block_index = 0; block_index < block_count; block_index++) { - const float block_d = ggml_rwkv_half_to_float_reference(row_blocks[block_index].d); - const float block_m = ggml_rwkv_half_to_float_reference(row_blocks[block_index].m); - - // 0 .. 
31 - const uint16_t outlier_index = row_blocks[block_index].outlier_index; - const float outlier_value = ggml_rwkv_half_to_float_reference(row_blocks[block_index].outlier_value); - - const uint8_t * restrict quant_nibbles = row_blocks[block_index].qs; - - // --- - - // Broadcast values to 8x element float32 vectors - const __m256 broadcasted_d = _mm256_broadcast_ss(&block_d); - const __m256 broadcasted_m = _mm256_broadcast_ss(&block_m); - const __m256 broadcasted_outlier_value = _mm256_broadcast_ss(&outlier_value); - - // Load 32x4-bit integers into 32x8-bit integers - const __m256i quant_bytes = bytesFromNibbles(quant_nibbles); - - // Convert to 16-bit int - const __m256i quant_shorts_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 0)); - const __m256i quant_shorts_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 1)); - - // Convert to 32-bit int and then to 32-bit float - const __m256 quant_floats_0 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 0))); - const __m256 quant_floats_1 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 1))); - const __m256 quant_floats_2 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 0))); - const __m256 quant_floats_3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 1))); - - // Dequantize to ~original weights - const __m256 weight_0 = _mm256_fmadd_ps(quant_floats_0, broadcasted_d, broadcasted_m); - const __m256 weight_1 = _mm256_fmadd_ps(quant_floats_1, broadcasted_d, broadcasted_m); - const __m256 weight_2 = _mm256_fmadd_ps(quant_floats_2, broadcasted_d, broadcasted_m); - const __m256 weight_3 = _mm256_fmadd_ps(quant_floats_3, broadcasted_d, broadcasted_m); - - // TODO This outlier handling is VERY slow - // Set outlier mask -- this should give 1 in the most significant bit - outlier_mask[outlier_index] = -1.0F; - // Load mask into vectors - const __m256 outlier_mask_0 = _mm256_load_ps(outlier_mask); - const __m256 outlier_mask_1 = _mm256_load_ps(outlier_mask + 8); - const __m256 outlier_mask_2 = _mm256_load_ps(outlier_mask + 16); - const __m256 outlier_mask_3 = _mm256_load_ps(outlier_mask + 24); - // Reset mask array to all zeroes for the next iteration - outlier_mask[outlier_index] = 0.0F; - - // Replace the weight at the index of the outlier - const __m256 weight_0_with_outlier = _mm256_blendv_ps(weight_0, broadcasted_outlier_value, outlier_mask_0); - const __m256 weight_1_with_outlier = _mm256_blendv_ps(weight_1, broadcasted_outlier_value, outlier_mask_1); - const __m256 weight_2_with_outlier = _mm256_blendv_ps(weight_2, broadcasted_outlier_value, outlier_mask_2); - const __m256 weight_3_with_outlier = _mm256_blendv_ps(weight_3, broadcasted_outlier_value, outlier_mask_3); - - // Load 32 floats of data of the second argument - const float * src1_data = (float *) ((char *) src1->data + (block_index * QK * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13)); - - const __m256 src1_0 = _mm256_load_ps(src1_data); - const __m256 src1_1 = _mm256_load_ps(src1_data + 8); - const __m256 src1_2 = _mm256_load_ps(src1_data + 16); - const __m256 src1_3 = _mm256_load_ps(src1_data + 24); - - // Multiply weights and values of the second argument element-wise; add to accumulator - accum = _mm256_fmadd_ps(src1_0, weight_0_with_outlier, accum); - accum = _mm256_fmadd_ps(src1_1, weight_1_with_outlier, accum); - accum = _mm256_fmadd_ps(src1_2, weight_2_with_outlier, accum); - accum = _mm256_fmadd_ps(src1_3, 
weight_3_with_outlier, accum); - } - - // Add elements of accumulator - __m128 res = _mm256_extractf128_ps(accum, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(accum)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res )); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - - *((float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3))) = _mm_cvtss_f32(res); - } -#else - float * const wdata = (float *) ((char *) params->wdata + (i01 * nb01 + i02 * nb02 + i03 * nb03)); - - dequantize_row_q4_1_o((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), wdata, ne00); - - for (int ic = 0; ic < ne11; ++ic) { - // src1 indices - const int i13 = i03; - const int i12 = i02; - const int i11 = ic; - - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; - - ggml_rwkv_vec_dot_f32( - ne00, - (float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)), - wdata, - (float *) ((char *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13)) - ); - } -#endif - } -} - -static void ggml_rwkv_compute_forward_mul_mat( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - { - ggml_rwkv_compute_forward_mul_mat_q_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - ggml_rwkv_compute_forward_mul_mat_q4_1_o_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_mul_mat_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - -#if 0 - if (src0->type == GGML_RWKV_TYPE_F16 || src0->type == GGML_RWKV_TYPE_Q4_1) { - static int first = 8; - printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); - printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); - printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - if (first) { - --first; - } else { - for (int k = 0; k < dst->ne[1]; ++k) { - for (int j = 0; j < dst->ne[0]/16; ++j) { - for (int i = 0; i < 16; ++i) { - printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - } - printf("\n"); - } - printf("\n"); - } - printf("\n"); - exit(0); - } - } else { - printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); - printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); - printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - } -#endif -} - -// ggml_rwkv_compute_forward_scale - -static void ggml_rwkv_compute_forward_scale_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(src0)); - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(dst)); - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(src0, dst)); - GGML_RWKV_ASSERT(ggml_rwkv_is_scalar(src1)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // scale factor - const float v = *(float *) src1->data; - - const int 
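Note: Q4_1_O is the RWKV-specific format handled by the AVX2 kernel above. Each 32-weight block carries an fp16 scale d and bias m plus one "outlier" weight stored exactly (outlier_index / outlier_value), so a single extreme value cannot blow up the block's quantization range; the vector code restores it with a blend mask, which the comment flags as the slow part. A scalar reference of the same dequantization, with an illustrative block layout and a minimal fp16 helper (subnormals flushed to zero, good enough for a sketch):

    #include <stdint.h>
    #include <string.h>

    #define TOY_QK 32

    // Mirrors the block_q4_1_o fields the kernel above reads; exact layout
    // and padding here are illustrative, not copied from the header.
    typedef struct {
        uint16_t d;              // fp16 scale
        uint16_t m;              // fp16 bias
        uint16_t outlier_index;  // 0..31: position of the exactly-stored weight
        uint16_t outlier_value;  // fp16 value of that weight
        uint8_t  qs[TOY_QK / 2]; // 4-bit codes for all 32 positions
    } toy_block_q4_1_o;

    static float toy_half_to_float(uint16_t h) {
        const uint32_t sign = (uint32_t)(h >> 15) << 31;
        const uint32_t exp  = (h >> 10) & 0x1F;
        const uint32_t mant = h & 0x3FF;
        uint32_t bits;
        if (exp == 0) {
            bits = sign;                                   // zero / subnormal
        } else if (exp == 31) {
            bits = sign | 0x7F800000u | (mant << 13);      // inf / NaN
        } else {
            bits = sign | ((exp + 112u) << 23) | (mant << 13);
        }
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }

    static void toy_dequantize_q4_1_o(const toy_block_q4_1_o * b, float * y) {
        const float d = toy_half_to_float(b->d);
        const float m = toy_half_to_float(b->m);

        for (int j = 0; j < TOY_QK / 2; ++j) {
            const uint8_t v = b->qs[j];
            y[2 * j + 0] = d * (float)(v & 0x0F) + m;
            y[2 * j + 1] = d * (float)(v >> 4)   + m;
        }

        // The one weight that did not fit the 4-bit range is stored verbatim
        // and overwrites whatever the nibble at that position decoded to.
        y[b->outlier_index] = toy_half_to_float(b->outlier_value);
    }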
ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_rwkv_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_rwkv_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v); - } -} - -static void ggml_rwkv_compute_forward_scale( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_scale_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_cpy - -static void ggml_rwkv_compute_forward_cpy( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - ggml_rwkv_compute_forward_dup(params, src0, dst); -} - -// ggml_rwkv_compute_forward_reshape - -static void ggml_rwkv_compute_forward_reshape( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(src0); - UNUSED(dst); -} - -// ggml_rwkv_compute_forward_view - -static void ggml_rwkv_compute_forward_view( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_rwkv_compute_forward_permute - -static void ggml_rwkv_compute_forward_permute( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_rwkv_compute_forward_transpose - -static void ggml_rwkv_compute_forward_transpose( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_rwkv_compute_forward_get_rows - -static void ggml_rwkv_compute_forward_get_rows_q( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_rwkv_nelements(src1); - const enum ggml_rwkv_type type = src0->type; - dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == GGML_RWKV_TYPE_SIZE[type]); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - dequantize_row_q( - (const void *) ((char *) src0->data + r*src0->nb[1]), - (float *) ((char *) dst->data + i*dst->nb[1]), nc); - } -} - -static void ggml_rwkv_compute_forward_get_rows_f16( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int nc = 
src0->ne[0]; - const int nr = ggml_rwkv_nelements(src1); - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(ggml_rwkv_fp16_t)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - for (int j = 0; j < nc; ++j) { - ggml_rwkv_fp16_t v = ((ggml_rwkv_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_RWKV_FP16_TO_FP32(v); - } - } -} - -static void ggml_rwkv_compute_forward_get_rows_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_rwkv_nelements(src1); - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - ggml_rwkv_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i*dst->nb[1]), - (float *) ((char *) src0->data + r*src0->nb[1])); - } -} - -static void ggml_rwkv_compute_forward_get_rows( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - { - ggml_rwkv_compute_forward_get_rows_q(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_get_rows_f16(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_get_rows_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} -} - -// ggml_rwkv_compute_forward_diag_mask_inf - -static void ggml_rwkv_compute_forward_diag_mask_inf_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(src1->type == GGML_RWKV_TYPE_I32); - assert(ggml_rwkv_nelements(src1) == 1); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n_past = ((int32_t *) src1->data)[0]; - - // TODO: handle transposed/permuted matrices - - const int n = ggml_rwkv_nrows(src0); - const int nc = src0->ne[0]; - const int nr = src0->ne[1]; - const int nz = n/nr; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int k = 0; k < nz; k++) { - for (int j = 0; j < nr; j++) { - for (int i = n_past; i < nc; i++) { - if (i > n_past + j) { - *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY; - } - } - } - } -} - -static void ggml_rwkv_compute_forward_diag_mask_inf( - const 
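Note: diag_mask_inf above is the causal attention mask. For a score matrix with nc key columns, nr query rows and n_past cached positions, every entry with column i > n_past + j is set to -INFINITY so softmax sends it to exactly zero. A single-matrix sketch of the same condition:

    #include <math.h>

    // Causal mask over an nr x nc score matrix (row j = query, column i = key):
    // keys beyond n_past + j have not been seen yet, so they are masked out.
    static void toy_diag_mask_inf(int nr, int nc, int n_past, float * s /* nr*nc */) {
        for (int j = 0; j < nr; ++j) {
            for (int i = n_past; i < nc; ++i) {
                if (i > n_past + j) {
                    s[j * nc + i] = -INFINITY;
                }
            }
        }
    }

    // Example with n_past = 2, nr = 3, nc = 5 ('.' kept, 'x' masked):
    //   row 0:  . . . x x
    //   row 1:  . . . . x
    //   row 2:  . . . . .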
struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_soft_max - -static void ggml_rwkv_compute_forward_soft_max_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(src0)); - GGML_RWKV_ASSERT(ggml_rwkv_is_contiguous(dst)); - GGML_RWKV_ASSERT(ggml_rwkv_are_same_shape(src0, dst)); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_rwkv_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float *p = (float *)((char *) dst->data + i1*dst->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(p[i])); - } -#endif - - float max = -INFINITY; - ggml_rwkv_vec_max_f32(nc, &max, p); - - ggml_rwkv_float sum = 0.0; - - uint16_t scvt; - for (int i = 0; i < nc; i++) { - if (p[i] == -INFINITY) { - p[i] = 0.0f; - } else { - //const float val = (p[i] == -INFINITY) ? 
0.0 : exp(p[i] - max); - ggml_rwkv_fp16_t s = GGML_RWKV_FP32_TO_FP16(p[i] - max); - memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_RWKV_FP16_TO_FP32(table_exp_f16[scvt]); - sum += (ggml_rwkv_float)val; - p[i] = val; - } - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_rwkv_vec_scale_f32(nc, p, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(p[i])); - assert(!isinf(p[i])); - } -#endif - } -} - -static void ggml_rwkv_compute_forward_soft_max( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_soft_max_f32(params, src0, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_F16: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_rope - -static void ggml_rwkv_compute_forward_rope_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(src1->type == GGML_RWKV_TYPE_I32); - assert(ggml_rwkv_nelements(src1) == 3); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - //const int ne0 = src0->ne[0]; - const int ne1 = src0->ne[1]; - const int ne2 = src0->ne[2]; - const int ne3 = src0->ne[3]; - - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; - const int nb3 = src0->nb[3]; - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - assert(nb0 == sizeof(float)); - - // TODO: optimize - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int p = (mode == 0 ? 
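Note: the softmax kernel subtracts the row maximum before exponentiating (the standard overflow guard) and, instead of calling expf, rounds p[i] - max to fp16 and looks the result up in a precomputed table indexed by the fp16 bit pattern (table_exp_f16), trading a little precision for speed. Stripped of the table, the per-row math is:

    #include <math.h>

    // Numerically stable softmax over one row, matching the structure above:
    // masked (-inf) entries become exact zeros and never touch expf().
    static void toy_softmax_row(int n, float * p) {
        float max = -INFINITY;
        for (int i = 0; i < n; ++i) if (p[i] > max) max = p[i];

        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            if (p[i] == -INFINITY) {
                p[i] = 0.0f;
            } else {
                p[i] = expf(p[i] - max);   // the real kernel uses the fp16 lookup table here
                sum += (double)p[i];
            }
        }
        const float inv = (float)(1.0 / sum);
        for (int i = 0; i < n; ++i) p[i] *= inv;
    }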
n_past + i2 : i2); - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < n_dims; i0 += 2) { - const float theta = powf(10000.0, ((float)-i0)/n_dims); - - const float cos_theta = cosf(p*theta); - const float sin_theta = sinf(p*theta); - - const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } - } - } -} - -static void ggml_rwkv_compute_forward_rope_f16( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - assert(params->ith == 0); - assert(src1->type == GGML_RWKV_TYPE_I32); - assert(ggml_rwkv_nelements(src1) == 3); - - if (params->type == GGML_RWKV_TASK_INIT || params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - //const int ne0 = src0->ne[0]; - const int ne1 = src0->ne[1]; - const int ne2 = src0->ne[2]; - const int ne3 = src0->ne[3]; - - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; - const int nb3 = src0->nb[3]; - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - assert(nb0 == sizeof(ggml_rwkv_fp16_t)); - - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int p = (mode == 0 ? n_past + i2 : i2); - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < n_dims; i0 += 2) { - const float theta = powf(10000.0, ((float)-i0)/n_dims); - - const float cos_theta = cosf(p*theta); - const float sin_theta = sinf(p*theta); - - const ggml_rwkv_fp16_t * const src = (ggml_rwkv_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - ggml_rwkv_fp16_t * dst_data = (ggml_rwkv_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = ggml_rwkv_fp16_to_fp32(src[0]); - const float x1 = ggml_rwkv_fp16_to_fp32(src[1]); - - dst_data[0] = ggml_rwkv_fp32_to_fp16(x0*cos_theta - x1*sin_theta); - dst_data[1] = ggml_rwkv_fp32_to_fp16(x0*sin_theta + x1*cos_theta); - } - } - } - } -} - -static void ggml_rwkv_compute_forward_rope( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_rope_f16(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_rope_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_conv_1d_1s - -static void ggml_rwkv_compute_forward_conv_1d_1s_f16_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(src0->type == GGML_RWKV_TYPE_F16); - GGML_RWKV_ASSERT(src1->type == GGML_RWKV_TYPE_F32); - GGML_RWKV_ASSERT( 
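Note: rope rotates consecutive (x0, x1) pairs by a position-dependent angle: dimension pair i0 uses theta = 10000^(-i0/n_dims), scaled by the position p, which is a complex multiplication by e^(i*p*theta). A scalar sketch of one row, matching the pair-wise rotation in rope_f32 above:

    #include <math.h>

    // Rotary position embedding over one row of n_dims floats at position p.
    static void toy_rope_row(int n_dims, int p, float * x) {
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float theta     = powf(10000.0f, -(float)i0 / n_dims);
            const float cos_theta = cosf(p * theta);
            const float sin_theta = sinf(p * theta);

            const float x0 = x[i0];
            const float x1 = x[i0 + 1];

            x[i0]     = x0 * cos_theta - x1 * sin_theta;
            x[i0 + 1] = x0 * sin_theta + x1 * cos_theta;
        }
    }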
dst->type == GGML_RWKV_TYPE_F32); - - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_rwkv_up32(ne01); - - GGML_RWKV_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_RWKV_ASSERT(nb00 == sizeof(ggml_rwkv_fp16_t)); - GGML_RWKV_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_RWKV_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - ggml_rwkv_fp16_t * const wdata = (ggml_rwkv_fp16_t *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const ggml_rwkv_fp16_t * const src = (ggml_rwkv_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_rwkv_fp16_t * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - ggml_rwkv_fp16_t * const wdata = (ggml_rwkv_fp16_t *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_rwkv_fp16_t * dst_data = wdata; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_RWKV_FP32_TO_FP16(src[i10]); - } - } - } - - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_rwkv_vec_dot_f16(ew0, &v, - (ggml_rwkv_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_rwkv_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; - } - } - } -} - -static void ggml_rwkv_compute_forward_conv_1d_1s_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(src0->type == GGML_RWKV_TYPE_F32); - GGML_RWKV_ASSERT(src1->type == GGML_RWKV_TYPE_F32); - GGML_RWKV_ASSERT( dst->type == GGML_RWKV_TYPE_F32); - - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - 
const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_rwkv_up32(ne01); - - GGML_RWKV_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_RWKV_ASSERT(nb00 == sizeof(float)); - GGML_RWKV_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_RWKV_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - float * const wdata = (float *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; - } - } - } - - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_rwkv_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; - } - } - } -} - -static void ggml_rwkv_compute_forward_conv_1d_1s( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_conv_1d_2s - -static void ggml_rwkv_compute_forward_conv_1d_2s_f16_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(src0->type == 
GGML_RWKV_TYPE_F16); - GGML_RWKV_ASSERT(src1->type == GGML_RWKV_TYPE_F32); - GGML_RWKV_ASSERT( dst->type == GGML_RWKV_TYPE_F32); - - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_rwkv_up32(ne01); - - GGML_RWKV_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_RWKV_ASSERT(nb00 == sizeof(ggml_rwkv_fp16_t)); - GGML_RWKV_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_RWKV_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - ggml_rwkv_fp16_t * const wdata = (ggml_rwkv_fp16_t *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const ggml_rwkv_fp16_t * const src = (ggml_rwkv_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_rwkv_fp16_t * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - ggml_rwkv_fp16_t * const wdata = (ggml_rwkv_fp16_t *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_rwkv_fp16_t * dst_data = wdata; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_RWKV_FP32_TO_FP16(src[i10]); - } - } - } - - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_rwkv_vec_dot_f16(ew0, &v, - (ggml_rwkv_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_rwkv_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; - } - } - } -} - -static void ggml_rwkv_compute_forward_conv_1d_2s_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - GGML_RWKV_ASSERT(src0->type == GGML_RWKV_TYPE_F32); - GGML_RWKV_ASSERT(src1->type == GGML_RWKV_TYPE_F32); - GGML_RWKV_ASSERT( dst->type == GGML_RWKV_TYPE_F32); - - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - 
const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_rwkv_up32(ne01); - - GGML_RWKV_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_RWKV_ASSERT(nb00 == sizeof(float)); - GGML_RWKV_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_RWKV_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - float * const wdata = (float *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; - } - } - } - - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_rwkv_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; - } - } - } -} - -static void ggml_rwkv_compute_forward_conv_1d_2s( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * src0, - const struct ggml_rwkv_tensor * src1, - struct ggml_rwkv_tensor * dst) { - switch (src0->type) { - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_flash_attn - -static void ggml_rwkv_compute_forward_flash_attn_f32( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * q, - const 
struct ggml_rwkv_tensor * k, - const struct ggml_rwkv_tensor * v, - const bool masked, - struct ggml_rwkv_tensor * dst) { - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int neq0 = q->ne[0]; - const int neq1 = q->ne[1]; - const int neq2 = q->ne[2]; - const int neq3 = q->ne[3]; - - const int nek0 = k->ne[0]; - const int nek1 = k->ne[1]; - //const int nek2 = k->ne[2]; - //const int nek3 = k->ne[3]; - - //const int nev0 = v->ne[0]; - const int nev1 = v->ne[1]; - //const int nev2 = v->ne[2]; - //const int nev3 = v->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - - const int nbk0 = k->nb[0]; - const int nbk1 = k->nb[1]; - const int nbk2 = k->nb[2]; - const int nbk3 = k->nb[3]; - - const int nbq0 = q->nb[0]; - const int nbq1 = q->nb[1]; - const int nbq2 = q->nb[2]; - const int nbq3 = q->nb[3]; - - const int nbv0 = v->nb[0]; - const int nbv1 = v->nb[1]; - const int nbv2 = v->nb[2]; - const int nbv3 = v->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int D = neq0; - const int N = neq1; - const int P = nek1 - N; - const int M = P + N; - - const int Mup = ggml_rwkv_up(M, GGML_RWKV_SOFT_MAX_UNROLL); - - GGML_RWKV_ASSERT(ne0 == D); - GGML_RWKV_ASSERT(ne1 == N); - GGML_RWKV_ASSERT(P >= 0); - - GGML_RWKV_ASSERT(nbq0 == sizeof(float)); - GGML_RWKV_ASSERT(nbk0 == sizeof(float)); - GGML_RWKV_ASSERT(nbv0 == sizeof(float)); - - GGML_RWKV_ASSERT(neq0 == D); - GGML_RWKV_ASSERT(nek0 == D); - GGML_RWKV_ASSERT(nev1 == D); - - GGML_RWKV_ASSERT(neq1 == N); - GGML_RWKV_ASSERT(nek1 == N + P); - GGML_RWKV_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_RWKV_ASSERT(nb0 == sizeof(float)); - GGML_RWKV_ASSERT(nb0 <= nb1); - GGML_RWKV_ASSERT(nb1 <= nb2); - GGML_RWKV_ASSERT(nb2 <= nb3); - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_rwkv_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - for (int ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_rwkv_vec_dot_f32(neq0, - S + i1, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - - // scale - ggml_rwkv_vec_scale_f32(nek1, S, scale); - - if (masked) { - for (int i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } - } - - // softmax - { - float max = -INFINITY; - ggml_rwkv_vec_max_f32(M, &max, S); - - ggml_rwkv_float sum = 0.0; - { -#ifdef GGML_RWKV_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - 
ggml_rwkv_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_RWKV_SOFT_MAX_UNROLL]; - ggml_rwkv_float sump[GGML_RWKV_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_RWKV_SOFT_MAX_UNROLL) { - float * SS = S + i; - - for (int j = 0; j < GGML_RWKV_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_rwkv_fp16_t s = GGML_RWKV_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_RWKV_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += (ggml_rwkv_float)val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_RWKV_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_rwkv_vec_scale_f32(M, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - for (int ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - ggml_rwkv_vec_dot_f32(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - S); - } - } -} - -static void ggml_rwkv_compute_forward_flash_attn_f16( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * q, - const struct ggml_rwkv_tensor * k, - const struct ggml_rwkv_tensor * v, - const bool masked, - struct ggml_rwkv_tensor * dst) { - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int neq0 = q->ne[0]; - const int neq1 = q->ne[1]; - const int neq2 = q->ne[2]; - const int neq3 = q->ne[3]; - - const int nek0 = k->ne[0]; - const int nek1 = k->ne[1]; - //const int nek2 = k->ne[2]; - //const int nek3 = k->ne[3]; - - //const int nev0 = v->ne[0]; - const int nev1 = v->ne[1]; - //const int nev2 = v->ne[2]; - //const int nev3 = v->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - - const int nbk0 = k->nb[0]; - const int nbk1 = k->nb[1]; - const int nbk2 = k->nb[2]; - const int nbk3 = k->nb[3]; - - const int nbq0 = q->nb[0]; - const int nbq1 = q->nb[1]; - const int nbq2 = q->nb[2]; - const int nbq3 = q->nb[3]; - - const int nbv0 = v->nb[0]; - const int nbv1 = v->nb[1]; - const int nbv2 = v->nb[2]; - const int nbv3 = v->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int D = neq0; - const int N = neq1; - const int P = nek1 - N; - const int M = P + N; - - const int Mup = ggml_rwkv_up(M, GGML_RWKV_SOFT_MAX_UNROLL); - - GGML_RWKV_ASSERT(ne0 == D); - GGML_RWKV_ASSERT(ne1 == N); - GGML_RWKV_ASSERT(P >= 0); - - GGML_RWKV_ASSERT(nbq0 == sizeof(ggml_rwkv_fp16_t)); - GGML_RWKV_ASSERT(nbk0 == sizeof(ggml_rwkv_fp16_t)); - GGML_RWKV_ASSERT(nbv0 == sizeof(ggml_rwkv_fp16_t)); - - GGML_RWKV_ASSERT(neq0 == D); - GGML_RWKV_ASSERT(nek0 == D); - GGML_RWKV_ASSERT(nev1 == D); - - GGML_RWKV_ASSERT(neq1 == N); - GGML_RWKV_ASSERT(nek1 == N + P); - GGML_RWKV_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_RWKV_ASSERT(nb0 == sizeof(float)); - GGML_RWKV_ASSERT(nb0 <= nb1); - GGML_RWKV_ASSERT(nb1 <= nb2); - GGML_RWKV_ASSERT(nb2 <= nb3); - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_rwkv_vec_dot_f32 - - // total rows in q - const int nr = 
neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - if (GGML_RWKV_VEC_DOT_UNROLL > 2 || nek1 % GGML_RWKV_VEC_DOT_UNROLL != 0) { - for (int ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_rwkv_vec_dot_f16(neq0, - S + i1, - (ggml_rwkv_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_rwkv_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int ic = 0; ic < nek1; ic += GGML_RWKV_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_rwkv_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_rwkv_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } - - // scale - ggml_rwkv_vec_scale_f32(nek1, S, scale); - - if (masked) { - for (int i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } - } - - // softmax - { - float max = -INFINITY; - ggml_rwkv_vec_max_f32(M, &max, S); - - ggml_rwkv_float sum = 0.0; - { -#ifdef GGML_RWKV_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_rwkv_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_RWKV_SOFT_MAX_UNROLL]; - ggml_rwkv_float sump[GGML_RWKV_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_RWKV_SOFT_MAX_UNROLL) { - float * SS = S + i; - - for (int j = 0; j < GGML_RWKV_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_rwkv_fp16_t s = GGML_RWKV_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_RWKV_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += (ggml_rwkv_float)val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_RWKV_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_rwkv_vec_scale_f32(M, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - ggml_rwkv_fp16_t * S16 = (ggml_rwkv_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int i = 0; i < M; i++) { - S16[i] = GGML_RWKV_FP32_TO_FP16(S[i]); - } - - if (GGML_RWKV_VEC_DOT_UNROLL == 1 || (nev1 % GGML_RWKV_VEC_DOT_UNROLL != 0)) { - for (int ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - ggml_rwkv_vec_dot_f16(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_rwkv_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - S16); - } - } else { - for (int ic = 0; ic < nev1; ic += GGML_RWKV_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - ggml_rwkv_vec_dot_f16_unroll(nek1, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + 
i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - S16); - } - } - } -} - -static void ggml_rwkv_compute_forward_flash_attn( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * q, - const struct ggml_rwkv_tensor * k, - const struct ggml_rwkv_tensor * v, - const bool masked, - struct ggml_rwkv_tensor * dst) { - switch (q->type) { - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - ggml_rwkv_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -// ggml_rwkv_compute_forward_flash_ff - -static void ggml_rwkv_compute_forward_flash_ff_f16( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * a, // F16 - const struct ggml_rwkv_tensor * b0, // F16 fc_w - const struct ggml_rwkv_tensor * b1, // F32 fc_b - const struct ggml_rwkv_tensor * c0, // F16 proj_w - const struct ggml_rwkv_tensor * c1, // F32 proj_b - struct ggml_rwkv_tensor * dst) { - int64_t t0 = ggml_rwkv_perf_time_us(); - UNUSED(t0); - - const int nea0 = a->ne[0]; - const int nea1 = a->ne[1]; - const int nea2 = a->ne[2]; - const int nea3 = a->ne[3]; - - const int neb00 = b0->ne[0]; - const int neb01 = b0->ne[1]; - //const int neb02 = b0->ne[2]; - //const int neb03 = b0->ne[3]; - - const int neb10 = b1->ne[0]; - const int neb11 = b1->ne[1]; - //const int neb12 = b1->ne[2]; - //const int neb13 = b1->ne[3]; - - const int nec00 = c0->ne[0]; - const int nec01 = c0->ne[1]; - //const int nec02 = c0->ne[2]; - //const int nec03 = c0->ne[3]; - - const int nec10 = c1->ne[0]; - const int nec11 = c1->ne[1]; - //const int nec12 = c1->ne[2]; - //const int nec13 = c1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - - const int nba0 = a->nb[0]; - const int nba1 = a->nb[1]; - const int nba2 = a->nb[2]; - const int nba3 = a->nb[3]; - - const int nbb00 = b0->nb[0]; - const int nbb01 = b0->nb[1]; - const int nbb02 = b0->nb[2]; - const int nbb03 = b0->nb[3]; - - const int nbb10 = b1->nb[0]; - //const int nbb11 = b1->nb[1]; - //const int nbb12 = b1->nb[2]; - //const int nbb13 = b1->nb[3]; - - const int nbc00 = c0->nb[0]; - const int nbc01 = c0->nb[1]; - const int nbc02 = c0->nb[2]; - const int nbc03 = c0->nb[3]; - - const int nbc10 = c1->nb[0]; - //const int nbc11 = c1->nb[1]; - //const int nbc12 = c1->nb[2]; - //const int nbc13 = c1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int D = nea0; - //const int N = nea1; - const int M = neb01; - - GGML_RWKV_ASSERT(ne0 == nea0); - GGML_RWKV_ASSERT(ne1 == nea1); - GGML_RWKV_ASSERT(ne2 == nea2); - - GGML_RWKV_ASSERT(nba0 == sizeof(ggml_rwkv_fp16_t)); - GGML_RWKV_ASSERT(nbb00 == sizeof(ggml_rwkv_fp16_t)); - GGML_RWKV_ASSERT(nbb10 == sizeof(float)); - GGML_RWKV_ASSERT(nbc00 == sizeof(ggml_rwkv_fp16_t)); - GGML_RWKV_ASSERT(nbc10 == sizeof(float)); - - GGML_RWKV_ASSERT(neb00 == D); - GGML_RWKV_ASSERT(neb01 == M); - GGML_RWKV_ASSERT(neb10 == M); - GGML_RWKV_ASSERT(neb11 == 1); - - GGML_RWKV_ASSERT(nec00 == M); - GGML_RWKV_ASSERT(nec01 
== D); - GGML_RWKV_ASSERT(nec10 == D); - GGML_RWKV_ASSERT(nec11 == 1); - - // dst cannot be transposed or permuted - GGML_RWKV_ASSERT(nb0 == sizeof(float)); - GGML_RWKV_ASSERT(nb0 <= nb1); - GGML_RWKV_ASSERT(nb1 <= nb2); - GGML_RWKV_ASSERT(nb2 <= nb3); - - if (params->type == GGML_RWKV_TASK_INIT) { - return; - } - - if (params->type == GGML_RWKV_TASK_FINALIZE) { - return; - } - - // parallelize by a rows using ggml_rwkv_vec_dot_f32 - - // total rows in a - const int nr = nea1*nea2*nea3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // a indices - const int ia3 = ir/(nea2*nea1); - const int ia2 = (ir - ia3*nea2*nea1)/nea1; - const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); - - float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); - - for (int ic = 0; ic < neb01; ++ic) { - // b0 indices - const int ib03 = ia3; - const int ib02 = ia2; - const int ib01 = ic; - - // S indices - const int i1 = ib01; - - ggml_rwkv_vec_dot_f16(nea0, - S + i1, - (ggml_rwkv_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), - (ggml_rwkv_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); - } - - ggml_rwkv_vec_add_f32(neb01, S, S, (float *) b1->data); - //ggml_rwkv_vec_gelu_f32(neb01, S, S); - - ggml_rwkv_fp16_t * S16 = (ggml_rwkv_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); - - for (int i = 0; i < M; i++) { - S16[i] = GGML_RWKV_FP32_TO_FP16(S[i]); - } - - ggml_rwkv_vec_gelu_f16(neb01, S16, S16); - - { - // dst indices - const int i1 = ia1; - const int i2 = ia2; - const int i3 = ia3; - - for (int ic = 0; ic < nec01; ++ic) { - - ggml_rwkv_vec_dot_f16(neb01, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_rwkv_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), - S16); - } - - ggml_rwkv_vec_add_f32(nec01, - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) c1->data); - } - } -} - -static void ggml_rwkv_compute_forward_flash_ff( - const struct ggml_rwkv_compute_params * params, - const struct ggml_rwkv_tensor * a, - const struct ggml_rwkv_tensor * b0, - const struct ggml_rwkv_tensor * b1, - const struct ggml_rwkv_tensor * c0, - const struct ggml_rwkv_tensor * c1, - struct ggml_rwkv_tensor * dst) { - switch (b0->type) { - case GGML_RWKV_TYPE_F16: - { - ggml_rwkv_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); - } break; - case GGML_RWKV_TYPE_F32: - { - GGML_RWKV_ASSERT(false); // TODO - } break; - case GGML_RWKV_TYPE_Q4_0: - case GGML_RWKV_TYPE_Q4_1: - case GGML_RWKV_TYPE_Q4_1_O: - case GGML_RWKV_TYPE_I8: - case GGML_RWKV_TYPE_I16: - case GGML_RWKV_TYPE_I32: - case GGML_RWKV_TYPE_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -///////////////////////////////// - -static void ggml_rwkv_compute_forward(struct ggml_rwkv_compute_params * params, struct ggml_rwkv_tensor * tensor) { - GGML_RWKV_ASSERT(params); - - switch (tensor->op) { - case GGML_RWKV_OP_DUP: - { - ggml_rwkv_compute_forward_dup(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_ADD: - { - ggml_rwkv_compute_forward_add(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_SUB: - { - ggml_rwkv_compute_forward_sub(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_MUL: - { - ggml_rwkv_compute_forward_mul(params, 
tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_DIV: - { - ggml_rwkv_compute_forward_div(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_SQR: - { - ggml_rwkv_compute_forward_sqr(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_SQRT: - { - ggml_rwkv_compute_forward_sqrt(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_SUM: - { - ggml_rwkv_compute_forward_sum(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_MEAN: - { - ggml_rwkv_compute_forward_mean(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_REPEAT: - { - ggml_rwkv_compute_forward_repeat(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_ABS: - { - ggml_rwkv_compute_forward_abs(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_SGN: - { - ggml_rwkv_compute_forward_sgn(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_NEG: - { - ggml_rwkv_compute_forward_neg(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_EXP: - { - ggml_rwkv_compute_forward_exp(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_1_MINUS_X: - { - ggml_rwkv_compute_forward_1_minus_x(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_MAX: - { - ggml_rwkv_compute_forward_max(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_STEP: - { - ggml_rwkv_compute_forward_step(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_RELU: - { - ggml_rwkv_compute_forward_relu(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_GELU: - { - ggml_rwkv_compute_forward_gelu(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_SIGMOID: - { - ggml_rwkv_compute_forward_sigmoid(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_SILU: - { - ggml_rwkv_compute_forward_silu(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_NORM: - { - ggml_rwkv_compute_forward_norm(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_RMS_NORM: - { - ggml_rwkv_compute_forward_rms_norm(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_MUL_MAT: - { - ggml_rwkv_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_SCALE: - { - ggml_rwkv_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_CPY: - { - ggml_rwkv_compute_forward_cpy(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_RESHAPE: - { - ggml_rwkv_compute_forward_reshape(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_VIEW: - { - ggml_rwkv_compute_forward_view(params, tensor->src0); - } break; - case GGML_RWKV_OP_PERMUTE: - { - ggml_rwkv_compute_forward_permute(params, tensor->src0); - } break; - case GGML_RWKV_OP_TRANSPOSE: - { - ggml_rwkv_compute_forward_transpose(params, tensor->src0); - } break; - case GGML_RWKV_OP_GET_ROWS: - { - ggml_rwkv_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_DIAG_MASK_INF: - { - ggml_rwkv_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_SOFT_MAX: - { - ggml_rwkv_compute_forward_soft_max(params, tensor->src0, tensor); - } break; - case GGML_RWKV_OP_ROPE: - { - ggml_rwkv_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_CONV_1D_1S: - { - ggml_rwkv_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_CONV_1D_2S: - { - 
ggml_rwkv_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_RWKV_OP_FLASH_ATTN: - { - int32_t t = ggml_rwkv_get_i32_1d(tensor->opt[1], 0); - GGML_RWKV_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - ggml_rwkv_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); - } break; - case GGML_RWKV_OP_FLASH_FF: - { - ggml_rwkv_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); - } break; - case GGML_RWKV_OP_NONE: - { - // nop - } break; - case GGML_RWKV_OP_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -//////////////////////////////////////////////////////////////////////////////// - -static void ggml_rwkv_compute_backward(struct ggml_rwkv_context * ctx, struct ggml_rwkv_tensor * tensor, bool inplace) { - struct ggml_rwkv_tensor * src0 = tensor->src0; - struct ggml_rwkv_tensor * src1 = tensor->src1; - - switch (tensor->op) { - case GGML_RWKV_OP_DUP: - { - if (src0->grad) { - src0->grad = ggml_rwkv_add_impl(ctx, src0->grad, tensor->grad, inplace); - } - } break; - case GGML_RWKV_OP_ADD: - { - if (src0->grad) { - src0->grad = ggml_rwkv_add_impl(ctx, src0->grad, tensor->grad, inplace); - } - if (src1->grad) { - src1->grad = ggml_rwkv_add_impl(ctx, src1->grad, tensor->grad, inplace); - } - } break; - case GGML_RWKV_OP_SUB: - { - if (src0->grad) { - src0->grad = ggml_rwkv_add_impl(ctx, src0->grad, tensor->grad, inplace); - } - if (src1->grad) { - src1->grad = ggml_rwkv_sub_impl(ctx, src1->grad, tensor->grad, inplace); - } - } break; - case GGML_RWKV_OP_MUL: - { - if (src0->grad) { - src0->grad = - ggml_rwkv_add_impl(ctx, - src0->grad, - ggml_rwkv_mul(ctx, src1, tensor->grad), - inplace); - } - if (src1->grad) { - src1->grad = - ggml_rwkv_add_impl(ctx, - src1->grad, - ggml_rwkv_mul(ctx, src0, tensor->grad), - inplace); - } - } break; - case GGML_RWKV_OP_DIV: - { - if (src0->grad) { - src0->grad = - ggml_rwkv_add_impl(ctx, - src0->grad, - ggml_rwkv_div(ctx, tensor->grad, src1), - inplace); - } - if (src1->grad) { - src1->grad = - ggml_rwkv_sub_impl(ctx, - src1->grad, - ggml_rwkv_mul(ctx, - tensor->grad, - ggml_rwkv_div(ctx, tensor, src1)), - inplace); - } - } break; - case GGML_RWKV_OP_SQR: - { - if (src0->grad) { - src0->grad = - ggml_rwkv_add_impl(ctx, - src0->grad, - ggml_rwkv_mul(ctx, - ggml_rwkv_mul(ctx, src0, tensor->grad), - ggml_rwkv_repeat(ctx, ggml_rwkv_new_f32(ctx, 2.0f), src0)), - inplace); - } - } break; - case GGML_RWKV_OP_SQRT: - { - if (src0->grad) { - src0->grad = - ggml_rwkv_add_impl(ctx, - src0->grad, - ggml_rwkv_div(ctx, - ggml_rwkv_repeat(ctx, ggml_rwkv_new_f32(ctx, 0.5f), tensor), - tensor), - inplace); - } - } break; - case GGML_RWKV_OP_SUM: - { - if (src0->grad) { - src0->grad = - ggml_rwkv_add_impl(ctx, - src0->grad, - ggml_rwkv_repeat(ctx, tensor->grad, src0->grad), - inplace); - } - } break; - case GGML_RWKV_OP_MEAN: - { - GGML_RWKV_ASSERT(false); // TODO: implement - } break; - case GGML_RWKV_OP_REPEAT: - { - if (src0->grad) { - src0->grad = - ggml_rwkv_add_impl(ctx, - src0->grad, - ggml_rwkv_sum(ctx, tensor->grad), - inplace); - } - } break; - case GGML_RWKV_OP_ABS: - { - if (src0->grad) { - src0->grad = - ggml_rwkv_add_impl(ctx, - src0->grad, - ggml_rwkv_mul(ctx, - ggml_rwkv_sgn(ctx, src0), - tensor->grad), - inplace); - } - } break; - case GGML_RWKV_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case GGML_RWKV_OP_NEG: - { - if (src0->grad) { - src0->grad = ggml_rwkv_sub_impl(ctx, src0->grad, 
tensor->grad, inplace); - } - } break; - case GGML_RWKV_OP_EXP: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_1_MINUS_X: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_MAX: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case GGML_RWKV_OP_RELU: - { - if (src0->grad) { - src0->grad = ggml_rwkv_sub_impl(ctx, - src0->grad, - ggml_rwkv_mul(ctx, - ggml_rwkv_step(ctx, src0), - tensor->grad), - inplace); - } - } break; - case GGML_RWKV_OP_GELU: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_SIGMOID: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_SILU: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_NORM: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_RMS_NORM: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_MUL_MAT: - { - if (src0->grad) { - // TODO: this requires outer product - ggml_rwkv_out_prod(ctx, src1, tensor->grad); - GGML_RWKV_ASSERT(false); - } - if (src1->grad) { - src1->grad = - ggml_rwkv_add_impl(ctx, - src1->grad, - // TODO: fix transpose, the node will break the graph connections - ggml_rwkv_mul_mat(ctx, ggml_rwkv_transpose(ctx, src0), tensor->grad), - inplace); - } - } break; - case GGML_RWKV_OP_SCALE: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_CPY: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_RESHAPE: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_VIEW: - { - GGML_RWKV_ASSERT(false); // not supported - } break; - case GGML_RWKV_OP_PERMUTE: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_TRANSPOSE: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_GET_ROWS: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_DIAG_MASK_INF: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_SOFT_MAX: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_ROPE: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_CONV_1D_1S: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_CONV_1D_2S: - { - GGML_RWKV_ASSERT(false); // TODO: not implemented - } break; - case GGML_RWKV_OP_FLASH_ATTN: - { - GGML_RWKV_ASSERT(false); // not supported - } break; - case GGML_RWKV_OP_FLASH_FF: - { - GGML_RWKV_ASSERT(false); // not supported - } break; - case GGML_RWKV_OP_NONE: - { - // nop - } break; - case GGML_RWKV_OP_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } -} - -static void ggml_rwkv_visit_parents(struct ggml_rwkv_cgraph * cgraph, struct ggml_rwkv_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != GGML_RWKV_OP_NONE) { - //GGML_RWKV_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - - // check if already visited - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return; - } - } - - for 
(int i = 0; i < cgraph->n_leafs; i++) { - if (cgraph->leafs[i] == node) { - return; - } - } - - if (node->src0) { - ggml_rwkv_visit_parents(cgraph, node->src0); - } - - if (node->src1) { - ggml_rwkv_visit_parents(cgraph, node->src1); - } - - for (int i = 0; i < GGML_RWKV_MAX_OPT; ++i) { - if (node->opt[i]) { - ggml_rwkv_visit_parents(cgraph, node->opt[i]); - } - } - - if (node->op == GGML_RWKV_OP_NONE && node->grad == NULL) { - // reached a leaf node, not part of the gradient graph (e.g. a constant) - GGML_RWKV_ASSERT(cgraph->n_leafs < GGML_RWKV_MAX_NODES); - - cgraph->leafs[cgraph->n_leafs] = node; - cgraph->n_leafs++; - } else { - GGML_RWKV_ASSERT(cgraph->n_nodes < GGML_RWKV_MAX_NODES); - - cgraph->nodes[cgraph->n_nodes] = node; - cgraph->grads[cgraph->n_nodes] = node->grad; - cgraph->n_nodes++; - } -} - -static void ggml_rwkv_build_forward_impl(struct ggml_rwkv_cgraph * cgraph, struct ggml_rwkv_tensor * tensor, bool expand) { - if (!expand) { - cgraph->n_nodes = 0; - cgraph->n_leafs = 0; - } - - const int n0 = cgraph->n_nodes; - UNUSED(n0); - - ggml_rwkv_visit_parents(cgraph, tensor); - - const int n_new = cgraph->n_nodes - n0; - GGML_RWKV_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); - - if (n_new > 0) { - // the last added node should always be starting point - GGML_RWKV_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); - } -} - -void ggml_rwkv_build_forward_expand(struct ggml_rwkv_cgraph * cgraph, struct ggml_rwkv_tensor * tensor) { - ggml_rwkv_build_forward_impl(cgraph, tensor, true); -} - -struct ggml_rwkv_cgraph ggml_rwkv_build_forward(struct ggml_rwkv_tensor * tensor) { - struct ggml_rwkv_cgraph result = { - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.n_threads =*/ 0, - /*.work_size =*/ 0, - /*.work =*/ NULL, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - }; - - ggml_rwkv_build_forward_impl(&result, tensor, false); - - return result; -} - -struct ggml_rwkv_cgraph ggml_rwkv_build_backward(struct ggml_rwkv_context * ctx, struct ggml_rwkv_cgraph * gf, bool keep) { - struct ggml_rwkv_cgraph result = *gf; - - GGML_RWKV_ASSERT(gf->n_nodes > 0); - - // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph - if (keep) { - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_rwkv_tensor * node = gf->nodes[i]; - - if (node->grad) { - node->grad = ggml_rwkv_dup_tensor(ctx, node); - gf->grads[i] = node->grad; - } - } - } - - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_rwkv_tensor * node = gf->nodes[i]; - - // because we detached the grad nodes from the original graph, we can afford inplace operations - if (node->grad) { - ggml_rwkv_compute_backward(ctx, node, keep); - } - } - - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_rwkv_tensor * node = gf->nodes[i]; - - if (node->is_param) { - GGML_RWKV_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_rwkv_build_forward_impl(&result, node->grad, true); - } - } - - return result; -} - -// -// thread data -// -// synchronization is done via busy loops -// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops -// - -#ifdef __APPLE__ - -//#include -// -//typedef os_unfair_lock ggml_rwkv_lock_t; -// -//#define ggml_rwkv_lock_init(x) UNUSED(x) -//#define ggml_rwkv_lock_destroy(x) UNUSED(x) -//#define ggml_rwkv_lock_lock os_unfair_lock_lock -//#define ggml_rwkv_lock_unlock 
os_unfair_lock_unlock -// -//#define GGML_RWKV_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT - -typedef int ggml_rwkv_lock_t; - -#define ggml_rwkv_lock_init(x) UNUSED(x) -#define ggml_rwkv_lock_destroy(x) UNUSED(x) -#define ggml_rwkv_lock_lock(x) UNUSED(x) -#define ggml_rwkv_lock_unlock(x) UNUSED(x) - -#define GGML_RWKV_LOCK_INITIALIZER 0 - -typedef pthread_t ggml_rwkv_thread_t; - -#define ggml_rwkv_thread_create pthread_create -#define ggml_rwkv_thread_join pthread_join - -#else - -//typedef pthread_spinlock_t ggml_rwkv_lock_t; - -//#define ggml_rwkv_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) -//#define ggml_rwkv_lock_destroy pthread_spin_destroy -//#define ggml_rwkv_lock_lock pthread_spin_lock -//#define ggml_rwkv_lock_unlock pthread_spin_unlock - -typedef int ggml_rwkv_lock_t; - -#define ggml_rwkv_lock_init(x) UNUSED(x) -#define ggml_rwkv_lock_destroy(x) UNUSED(x) -#define ggml_rwkv_lock_lock(x) UNUSED(x) -#define ggml_rwkv_lock_unlock(x) UNUSED(x) - -#define GGML_RWKV_LOCK_INITIALIZER 0 - -typedef pthread_t ggml_rwkv_thread_t; - -#define ggml_rwkv_thread_create pthread_create -#define ggml_rwkv_thread_join pthread_join - -#endif - -struct ggml_rwkv_compute_state_shared { - ggml_rwkv_lock_t spin; - - int n_threads; - - // synchronization primitives - atomic_int n_ready; - atomic_bool has_work; - atomic_bool stop; // stop all threads -}; - -struct ggml_rwkv_compute_state { - ggml_rwkv_thread_t thrd; - - struct ggml_rwkv_compute_params params; - struct ggml_rwkv_tensor * node; - - struct ggml_rwkv_compute_state_shared * shared; -}; - -static thread_ret_t ggml_rwkv_graph_compute_thread(void * data) { - struct ggml_rwkv_compute_state * state = (struct ggml_rwkv_compute_state *) data; - - const int n_threads = state->shared->n_threads; - - while (true) { - if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { - atomic_store(&state->shared->has_work, false); - } else { - while (atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; - } - ggml_rwkv_lock_lock (&state->shared->spin); - ggml_rwkv_lock_unlock(&state->shared->spin); - } - } - - atomic_fetch_sub(&state->shared->n_ready, 1); - - // wait for work - while (!atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; - } - ggml_rwkv_lock_lock (&state->shared->spin); - ggml_rwkv_lock_unlock(&state->shared->spin); - } - - // check if we should stop - if (atomic_load(&state->shared->stop)) { - break; - } - - if (state->node) { - if (state->params.ith < state->params.nth) { - ggml_rwkv_compute_forward(&state->params, state->node); - } - - state->node = NULL; - } else { - break; - } - } - - return 0; -} - -void ggml_rwkv_graph_compute(struct ggml_rwkv_context * ctx, struct ggml_rwkv_cgraph * cgraph) { - const int n_threads = cgraph->n_threads; - - struct ggml_rwkv_compute_state_shared state_shared = { - /*.spin =*/ GGML_RWKV_LOCK_INITIALIZER, - /*.n_threads =*/ n_threads, - /*.n_ready =*/ 0, - /*.has_work =*/ false, - /*.stop =*/ false, - }; - struct ggml_rwkv_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_rwkv_compute_state)*(n_threads - 1)) : NULL; - - // create thread pool - if (n_threads > 1) { - ggml_rwkv_lock_init(&state_shared.spin); - - atomic_store(&state_shared.has_work, true); - - for (int j = 0; j < n_threads - 1; j++) { - workers[j] = (struct ggml_rwkv_compute_state) { - .thrd = 0, - .params = { - .type = GGML_RWKV_TASK_COMPUTE, - .ith = j + 1, - .nth = n_threads, - .wsize = cgraph->work ? 
ggml_rwkv_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }, - .node = NULL, - .shared = &state_shared, - }; - - int rc = ggml_rwkv_thread_create(&workers[j].thrd, NULL, ggml_rwkv_graph_compute_thread, &workers[j]); - GGML_RWKV_ASSERT(rc == 0); - UNUSED(rc); - } - } - - // initialize tasks + work buffer - { - size_t work_size = 0; - - // thread scheduling for the different operations - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_rwkv_tensor * node = cgraph->nodes[i]; - - switch (node->op) { - case GGML_RWKV_OP_DUP: - { - node->n_tasks = 1; - } break; - case GGML_RWKV_OP_ADD: - { - node->n_tasks = n_threads; - } break; - case GGML_RWKV_OP_SUB: - case GGML_RWKV_OP_MUL: - case GGML_RWKV_OP_DIV: - case GGML_RWKV_OP_SQR: - case GGML_RWKV_OP_SQRT: - case GGML_RWKV_OP_SUM: - case GGML_RWKV_OP_MEAN: - case GGML_RWKV_OP_REPEAT: - case GGML_RWKV_OP_ABS: - case GGML_RWKV_OP_SGN: - case GGML_RWKV_OP_NEG: - case GGML_RWKV_OP_EXP: - case GGML_RWKV_OP_1_MINUS_X: - case GGML_RWKV_OP_MAX: - case GGML_RWKV_OP_STEP: - case GGML_RWKV_OP_RELU: - case GGML_RWKV_OP_SIGMOID: - { - node->n_tasks = 1; - } break; - case GGML_RWKV_OP_GELU: - { - node->n_tasks = n_threads; - } break; - case GGML_RWKV_OP_SILU: - { - node->n_tasks = n_threads; - } break; - case GGML_RWKV_OP_NORM: - case GGML_RWKV_OP_RMS_NORM: - { - node->n_tasks = n_threads; - } break; - case GGML_RWKV_OP_MUL_MAT: - { - node->n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - //const int nr0 = ggml_rwkv_nrows(node->src0); - //const int nr1 = ggml_rwkv_nrows(node->src1); - - //node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); - - size_t cur = 0; - - if (node->src0->type == GGML_RWKV_TYPE_F16 && node->src1->type == GGML_RWKV_TYPE_F32) { -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - if (ggml_rwkv_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - cur = GGML_RWKV_TYPE_SIZE[GGML_RWKV_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); - //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); - //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); - //printf("cur = %zu\n", cur); - } else { - cur = GGML_RWKV_TYPE_SIZE[GGML_RWKV_TYPE_F16]*ggml_rwkv_nelements(node->src1); - } -#else - cur = GGML_RWKV_TYPE_SIZE[GGML_RWKV_TYPE_F16]*ggml_rwkv_nelements(node->src1); -#endif - } else if (node->src0->type == GGML_RWKV_TYPE_F32 && node->src1->type == GGML_RWKV_TYPE_F32) { - cur = 0; - } else if (node->src0->type == GGML_RWKV_TYPE_Q4_1_O && node->src1->type == GGML_RWKV_TYPE_F32) { -#if defined(__AVX2__) - cur = 0; -#else - // Assuming that src1 is a vector - // TODO Not sure whether this is correct - cur = GGML_RWKV_TYPE_SIZE[GGML_RWKV_TYPE_F32] * ggml_rwkv_nelements(node->src1); -#endif - } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_RWKV_TYPE_F32) { -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - if (ggml_rwkv_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; - cur = GGML_RWKV_TYPE_SIZE[GGML_RWKV_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); - } else -#endif - { - cur = 
GGML_RWKV_TYPE_SIZE[node->src0->type]*ggml_rwkv_nelements(node->src1)/GGML_RWKV_BLCK_SIZE[node->src0->type]; - } - } else { - GGML_RWKV_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_RWKV_OP_SCALE: - { - node->n_tasks = n_threads; - } break; - case GGML_RWKV_OP_CPY: - case GGML_RWKV_OP_RESHAPE: - case GGML_RWKV_OP_VIEW: - case GGML_RWKV_OP_PERMUTE: - case GGML_RWKV_OP_TRANSPOSE: - case GGML_RWKV_OP_GET_ROWS: - case GGML_RWKV_OP_DIAG_MASK_INF: - { - node->n_tasks = 1; - } break; - case GGML_RWKV_OP_SOFT_MAX: - { - node->n_tasks = n_threads; - } break; - case GGML_RWKV_OP_ROPE: - { - node->n_tasks = 1; - } break; - case GGML_RWKV_OP_CONV_1D_1S: - case GGML_RWKV_OP_CONV_1D_2S: - { - node->n_tasks = n_threads; - - GGML_RWKV_ASSERT(node->src0->ne[3] == 1); - GGML_RWKV_ASSERT(node->src1->ne[2] == 1); - GGML_RWKV_ASSERT(node->src1->ne[3] == 1); - - size_t cur = 0; - const int nk = node->src0->ne[0]; - - if (node->src0->type == GGML_RWKV_TYPE_F16 && - node->src1->type == GGML_RWKV_TYPE_F32) { - cur = sizeof(ggml_rwkv_fp16_t)*( - nk*ggml_rwkv_up32(node->src0->ne[1])*node->src0->ne[2] + - ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] - ); - } else if (node->src0->type == GGML_RWKV_TYPE_F32 && - node->src1->type == GGML_RWKV_TYPE_F32) { - cur = sizeof(float)*( - nk*ggml_rwkv_up32(node->src0->ne[1])*node->src0->ne[2] + - ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] - ); - } else { - GGML_RWKV_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_RWKV_OP_FLASH_ATTN: - { - node->n_tasks = n_threads; - - size_t cur = 0; - - const int ne11 = ggml_rwkv_up(node->src1->ne[1], GGML_RWKV_SOFT_MAX_UNROLL); - - if (node->src1->type == GGML_RWKV_TYPE_F32) { - cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 - } - - if (node->src1->type == GGML_RWKV_TYPE_F16) { - cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_RWKV_OP_FLASH_FF: - { - node->n_tasks = n_threads; - - size_t cur = 0; - - if (node->src1->type == GGML_RWKV_TYPE_F32) { - cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 - } - - if (node->src1->type == GGML_RWKV_TYPE_F16) { - cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_RWKV_OP_NONE: - { - node->n_tasks = 1; - } break; - case GGML_RWKV_OP_COUNT: - { - GGML_RWKV_ASSERT(false); - } break; - } - } - - if (cgraph->work != NULL && work_size > cgraph->work_size) { - GGML_RWKV_ASSERT(false); // TODO: better handling - } - - if (work_size > 0 && cgraph->work == NULL) { - cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1); - - GGML_RWKV_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); - cgraph->work = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_I8, cgraph->work_size); - } - } - - const int64_t perf_start_cycles = ggml_rwkv_perf_cycles(); - const int64_t perf_start_time_us = ggml_rwkv_perf_time_us(); - - for (int i = 0; i < cgraph->n_nodes; i++) { - GGML_RWKV_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, 
cgraph->n_nodes); - - struct ggml_rwkv_tensor * node = cgraph->nodes[i]; - - // TODO: this could be used to avoid unnecessary computations, but it needs to be improved - //if (node->grad == NULL && node->perf_runs > 0) { - // continue; - //} - - const int64_t perf_node_start_cycles = ggml_rwkv_perf_cycles(); - const int64_t perf_node_start_time_us = ggml_rwkv_perf_time_us(); - - // INIT - struct ggml_rwkv_compute_params params = { - /*.type =*/ GGML_RWKV_TASK_INIT, - /*.ith =*/ 0, - /*.nth =*/ node->n_tasks, - /*.wsize =*/ cgraph->work ? ggml_rwkv_nbytes(cgraph->work) : 0, - /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, - }; - - ggml_rwkv_compute_forward(¶ms, node); - - // COMPUTE - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_rwkv_compute_params) { - .type = GGML_RWKV_TASK_COMPUTE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_rwkv_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }; - workers[j].node = node; - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - - atomic_store(&state_shared.has_work, true); - } - - params.type = GGML_RWKV_TASK_COMPUTE; - ggml_rwkv_compute_forward(¶ms, node); - - // wait for thread pool - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - } - - // FINALIZE - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_rwkv_compute_params) { - .type = GGML_RWKV_TASK_FINALIZE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_rwkv_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? 
cgraph->work->data : NULL, - }; - workers[j].node = node; - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - - atomic_store(&state_shared.has_work, true); - } - - params.type = GGML_RWKV_TASK_FINALIZE; - ggml_rwkv_compute_forward(¶ms, node); - - // wait for thread pool - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_rwkv_lock_lock (&state_shared.spin); - ggml_rwkv_lock_unlock(&state_shared.spin); - } - } - - // performance stats (node) - { - int64_t perf_cycles_cur = ggml_rwkv_perf_cycles() - perf_node_start_cycles; - int64_t perf_time_us_cur = ggml_rwkv_perf_time_us() - perf_node_start_time_us; - - node->perf_runs++; - node->perf_cycles += perf_cycles_cur; - node->perf_time_us += perf_time_us_cur; - } - } - - // join thread pool - if (n_threads > 1) { - atomic_store(&state_shared.stop, true); - atomic_store(&state_shared.has_work, true); - - for (int j = 0; j < n_threads - 1; j++) { - int rc = ggml_rwkv_thread_join(workers[j].thrd, NULL); - GGML_RWKV_ASSERT(rc == 0); - UNUSED(rc); - } - - ggml_rwkv_lock_destroy(&state_shared.spin); - } - - // performance stats (graph) - { - int64_t perf_cycles_cur = ggml_rwkv_perf_cycles() - perf_start_cycles; - int64_t perf_time_us_cur = ggml_rwkv_perf_time_us() - perf_start_time_us; - - cgraph->perf_runs++; - cgraph->perf_cycles += perf_cycles_cur; - cgraph->perf_time_us += perf_time_us_cur; - - GGML_RWKV_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", - __func__, cgraph->perf_runs, - (double) perf_cycles_cur / (double) ggml_rwkv_cycles_per_ms(), - (double) cgraph->perf_cycles / (double) ggml_rwkv_cycles_per_ms() / (double) cgraph->perf_runs, - (double) perf_time_us_cur / 1000.0, - (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); - } -} - -void ggml_rwkv_graph_reset(struct ggml_rwkv_cgraph * cgraph) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_rwkv_tensor * grad = cgraph->grads[i]; - - if (grad) { - ggml_rwkv_set_zero(grad); - } - } -} - -void ggml_rwkv_graph_print(const struct ggml_rwkv_cgraph * cgraph) { - int64_t perf_total_per_op_us[GGML_RWKV_OP_COUNT] = {0}; - - GGML_RWKV_PRINT("=== GRAPH ===\n"); - - GGML_RWKV_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads); - GGML_RWKV_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size); - - GGML_RWKV_PRINT("n_nodes = %d\n", cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_rwkv_tensor * node = cgraph->nodes[i]; - - perf_total_per_op_us[node->op] += node->perf_time_us; - - GGML_RWKV_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", - i, - node->ne[0], node->ne[1], node->ne[2], - GGML_RWKV_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? 
"g" : " ", node->perf_runs, - (double) node->perf_cycles / (double) ggml_rwkv_cycles_per_ms(), - (double) node->perf_cycles / (double) ggml_rwkv_cycles_per_ms() / (double) node->perf_runs, - (double) node->perf_time_us / 1000.0, - (double) node->perf_time_us / 1000.0 / node->perf_runs); - } - - GGML_RWKV_PRINT("n_leafs = %d\n", cgraph->n_leafs); - for (int i = 0; i < cgraph->n_leafs; i++) { - struct ggml_rwkv_tensor * node = cgraph->leafs[i]; - - GGML_RWKV_PRINT(" - %3d: [ %6d, %6d] %8s\n", - i, - node->ne[0], node->ne[1], - GGML_RWKV_OP_LABEL[node->op]); - } - - for (int i = 0; i < GGML_RWKV_OP_COUNT; i++) { - GGML_RWKV_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_RWKV_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0); - } - - GGML_RWKV_PRINT("========================================\n"); -} - -// check if node is part of the graph -static bool ggml_rwkv_graph_find(const struct ggml_rwkv_cgraph * cgraph, const struct ggml_rwkv_tensor * node) { - if (cgraph == NULL) { - return true; - } - - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return true; - } - } - - return false; -} - -static struct ggml_rwkv_tensor * ggml_rwkv_graph_get_parent(const struct ggml_rwkv_cgraph * cgraph, const struct ggml_rwkv_tensor * node) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_rwkv_tensor * parent = cgraph->nodes[i]; - - if (parent->grad == node) { - return parent; - } - } - - return NULL; -} - -void ggml_rwkv_graph_dump_dot(const struct ggml_rwkv_cgraph * gb, const struct ggml_rwkv_cgraph * gf, const char * filename) { - char color[16]; - - FILE * fp = fopen(filename, "w"); - GGML_RWKV_ASSERT(fp); - - fprintf(fp, "digraph G {\n"); - fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = LR;\n"); - - for (int i = 0; i < gb->n_nodes; i++) { - struct ggml_rwkv_tensor * node = gb->nodes[i]; - - if (ggml_rwkv_graph_get_parent(gb, node) != NULL) { - continue; - } - - if (node->is_param) { - snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { - if (ggml_rwkv_graph_find(gf, node)) { - snprintf(color, sizeof(color), "green"); - } else { - snprintf(color, sizeof(color), "lightblue"); - } - } else { - snprintf(color, sizeof(color), "white"); - } - - fprintf(fp, " \"%p\" [ \ -style = filled; fillcolor = %s; shape = record; \ -label=\"%d [%d, %d] | %s", - (void *) node, color, - i, node->ne[0], node->ne[1], - GGML_RWKV_OP_SYMBOL[node->op]); - - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", GGML_RWKV_OP_SYMBOL[node->grad->op]); - } else { - fprintf(fp, "\"; ]\n"); - } - } - - for (int i = 0; i < gb->n_leafs; i++) { - struct ggml_rwkv_tensor * node = gb->leafs[i]; - - snprintf(color, sizeof(color), "pink"); - - if (ggml_rwkv_nelements(node) == 1) { - fprintf(fp, " \"%p\" [ \ -style = filled; fillcolor = %s; shape = record; \ -label=\"%.1e\"; ]\n", - (void *) node, color, (double)ggml_rwkv_get_f32_1d(node, 0)); - } else { - fprintf(fp, " \"%p\" [ \ -style = filled; fillcolor = %s; shape = record; \ -label=\"CONST %d [%d, %d]\"; ]\n", - (void *) node, color, - i, node->ne[0], node->ne[1]); - } - } - - for (int i = 0; i < gb->n_nodes; i++) { - struct ggml_rwkv_tensor * node = gb->nodes[i]; - - struct ggml_rwkv_tensor * parent = ggml_rwkv_graph_get_parent(gb, node); - - if (node->src0) { - struct ggml_rwkv_tensor * parent0 = ggml_rwkv_graph_get_parent(gb, node->src0); - - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n", - parent0 ? (void *) parent0 : (void *) node->src0, - parent0 ? 
"g" : "x", - parent ? (void *) parent : (void *) node, - parent ? "g" : "x", - parent ? "empty" : "vee", - parent ? "dashed" : "solid"); - } - - if (node->src1) { - struct ggml_rwkv_tensor * parent1 = ggml_rwkv_graph_get_parent(gb, node->src1); - - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n", - parent1 ? (void *) parent1 : (void *) node->src1, - parent1 ? "g" : "x", - parent ? (void *) parent : (void *) node, - parent ? "g" : "x", - parent ? "empty" : "vee", - parent ? "dashed" : "solid"); - } - } - - for (int i = 0; i < gb->n_leafs; i++) { - struct ggml_rwkv_tensor * node = gb->leafs[i]; - - if (node->src0) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n", - (void *) node->src0, "x", - (void *) node, "x"); - } - - if (node->src1) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n", - (void *) node->src1, "x", - (void *) node, "x"); - } - } - - fprintf(fp, "}\n"); - - fclose(fp); - - GGML_RWKV_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); -} - -//////////////////////////////////////////////////////////////////////////////// - -static void ggml_rwkv_opt_set_params(int np, struct ggml_rwkv_tensor * const ps[], const float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int ne = ggml_rwkv_nelements(ps[p]) ; - // TODO: add function to set tensor from array - for (int j = 0; j < ne; ++j) { - ggml_rwkv_set_f32_1d(ps[p], j, x[i++]); - } - } -} - -static void ggml_rwkv_opt_get_params(int np, struct ggml_rwkv_tensor * const ps[], float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int ne = ggml_rwkv_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int j = 0; j < ne; ++j) { - x[i++] = ggml_rwkv_get_f32_1d(ps[p], j); - } - } -} - -static void ggml_rwkv_opt_get_grad(int np, struct ggml_rwkv_tensor * const ps[], float * g) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int ne = ggml_rwkv_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int j = 0; j < ne; ++j) { - g[i++] = ggml_rwkv_get_f32_1d(ps[p]->grad, j); - } - } -} - -// -// ADAM -// -// ref: https://arxiv.org/pdf/1412.6980.pdf -// - -static enum ggml_rwkv_opt_result ggml_rwkv_opt_adam( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_opt_params params, - struct ggml_rwkv_tensor * f, - struct ggml_rwkv_cgraph * gf, - struct ggml_rwkv_cgraph * gb) { - GGML_RWKV_ASSERT(ggml_rwkv_is_scalar(f)); - - gf->n_threads = params.n_threads; - gb->n_threads = params.n_threads; - - // these will store the parameters we want to optimize - struct ggml_rwkv_tensor * ps[GGML_RWKV_MAX_PARAMS]; - - int np = 0; - int nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->is_param) { - GGML_RWKV_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - GGML_RWKV_ASSERT(np < GGML_RWKV_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_rwkv_nelements(gf->nodes[i]); - } - } - - // constants - const float alpha = params.adam.alpha; - const float beta1 = params.adam.beta1; - const float beta2 = params.adam.beta2; - const float eps = params.adam.eps; - - float * x = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // view of the parameters - float * g1 = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // gradient - float * g2 = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // gradient squared - float * m = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // first moment - 
float * v = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // second moment - float * mh = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // first moment hat - float * vh = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // second moment hat - - float * pf = params.past > 0 ? ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, params.past)->data : NULL; // past function values - - // initialize - ggml_rwkv_vec_set_f32(nx, m, 0.0f); - ggml_rwkv_vec_set_f32(nx, v, 0.0f); - - // update view - ggml_rwkv_opt_get_params(np, ps, x); - - // compute the function value - ggml_rwkv_graph_reset (gf); - ggml_rwkv_set_f32 (f->grad, 1.0f); - ggml_rwkv_graph_compute(ctx, gb); - - float fx_prev = ggml_rwkv_get_f32_1d(f, 0); - if (pf) { - pf[0] = fx_prev; - } - - int n_no_improvement = 0; - float fx_best = fx_prev; - - // run the optimizer - for (int t = 0; t < params.adam.n_iter; ++t) { - GGML_RWKV_PRINT_DEBUG ("=== iter %d ===\n", t); - - GGML_RWKV_PRINT_DEBUG ("f = %10.6f\n", ggml_rwkv_get_f32_1d(f, 0)); - GGML_RWKV_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_rwkv_get_f32_1d(ps[0]->grad, 0)); - GGML_RWKV_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_rwkv_get_f32_1d(ps[1]->grad, 0)); - - for (int i = 0; i < np; ++i) { - GGML_RWKV_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, - ggml_rwkv_get_f32_1d(ps[i], 0), ggml_rwkv_get_f32_1d(ps[i]->grad, 0)); - } - - const int64_t t_start_wall = ggml_rwkv_time_us(); - const int64_t t_start_cpu = ggml_rwkv_cycles(); - UNUSED(t_start_wall); - UNUSED(t_start_cpu); - - { - // update the gradient - ggml_rwkv_opt_get_grad(np, ps, g1); - - // m_t = beta1*m_t-1 + (1 - beta1)*g_t - ggml_rwkv_vec_scale_f32(nx, m, beta1); - ggml_rwkv_vec_mad_f32 (nx, m, g1, 1.0f - beta1); - - // g2 = g1^2 - ggml_rwkv_vec_sqr_f32 (nx, g2, g1); - - // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 - ggml_rwkv_vec_scale_f32(nx, v, beta2); - ggml_rwkv_vec_mad_f32 (nx, v, g2, 1.0f - beta2); - - // m^hat = m_t / (1 - beta1^t) - // v^hat = v_t / (1 - beta2^t) - // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps) - ggml_rwkv_vec_cpy_f32 (nx, mh, m); - ggml_rwkv_vec_cpy_f32 (nx, vh, v); - - ggml_rwkv_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1))); - ggml_rwkv_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1))); - - ggml_rwkv_vec_sqrt_f32 (nx, vh, vh); - ggml_rwkv_vec_acc1_f32 (nx, vh, eps); - - ggml_rwkv_vec_div_f32 (nx, mh, mh, vh); - ggml_rwkv_vec_sub_f32 (nx, x, x, mh); - - // update the parameters - ggml_rwkv_opt_set_params(np, ps, x); - } - - ggml_rwkv_graph_reset (gf); - ggml_rwkv_set_f32 (f->grad, 1.0f); - ggml_rwkv_graph_compute(ctx, gb); - - const float fx = ggml_rwkv_get_f32_1d(f, 0); - - // check convergence - if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) { - GGML_RWKV_PRINT_DEBUG("converged\n"); - - return GGML_RWKV_OPT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= t) { - const float rate = (pf[t%params.past] - fx)/fx; - - if (fabsf(rate) < params.delta) { - return GGML_RWKV_OPT_OK; - } - } - - pf[t%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx_best > fx) { - fx_best = fx; - n_no_improvement = 0; - } else { - ++n_no_improvement; - - if (n_no_improvement >= params.max_no_improvement) { - return GGML_RWKV_OPT_OK; - } - } - } - - fx_prev = fx; - - { - const int64_t t_end_cpu = ggml_rwkv_cycles(); - GGML_RWKV_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); - 
UNUSED(t_end_cpu); - - const int64_t t_end_wall = ggml_rwkv_time_us(); - GGML_RWKV_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); - UNUSED(t_end_wall); - } - } - - return GGML_RWKV_OPT_DID_NOT_CONVERGE; -} - -// -// L-BFGS -// -// the L-BFGS implementation below is based on the following implementation: -// -// https://github.com/chokkan/liblbfgs -// - -struct ggml_rwkv_lbfgs_iteration_data { - float alpha; - float ys; - float * s; - float * y; -}; - -static enum ggml_rwkv_opt_result linesearch_backtracking( - struct ggml_rwkv_context * ctx, - const struct ggml_rwkv_opt_params * params, - int nx, - float * x, - float * fx, - float * g, - float * d, - float * step, - const float * xp, - struct ggml_rwkv_tensor * f, - struct ggml_rwkv_cgraph * gf, - struct ggml_rwkv_cgraph * gb, - const int np, - struct ggml_rwkv_tensor * ps[]) { - int count = 0; - - float width = 0.0f; - float dg = 0.0f; - float finit = 0.0f; - float dginit = 0.0f; - float dgtest = 0.0f; - - const float dec = 0.5f; - const float inc = 2.1f; - - if (*step <= 0.f) { - return GGML_RWKV_LINESEARCH_INVALID_PARAMETERS; - } - - // compute the initial gradient in the search direction - ggml_rwkv_vec_dot_f32(nx, &dginit, g, d); - - // make sure that d points to a descent direction - if (0 < dginit) { - return GGML_RWKV_LINESEARCH_FAIL; - } - - // initialize local variables - finit = *fx; - dgtest = params->lbfgs.ftol*dginit; - - while (true) { - ggml_rwkv_vec_cpy_f32(nx, x, xp); - ggml_rwkv_vec_mad_f32(nx, x, d, *step); - - // evaluate the function and gradient values - { - ggml_rwkv_opt_set_params(np, ps, x); - - ggml_rwkv_graph_reset (gf); - ggml_rwkv_set_f32 (f->grad, 1.0f); - ggml_rwkv_graph_compute(ctx, gb); - - ggml_rwkv_opt_get_grad(np, ps, g); - - *fx = ggml_rwkv_get_f32_1d(f, 0); - } - - ++count; - - if (*fx > finit + (*step)*dgtest) { - width = dec; - } else { - // Armijo condition is satisfied - if (params->lbfgs.linesearch == GGML_RWKV_LINESEARCH_BACKTRACKING_ARMIJO) { - return count; - } - - ggml_rwkv_vec_dot_f32(nx, &dg, g, d); - - // check the Wolfe condition - if (dg < params->lbfgs.wolfe * dginit) { - width = inc; - } else { - if(params->lbfgs.linesearch == GGML_RWKV_LINESEARCH_BACKTRACKING_WOLFE) { - // regular Wolfe conditions - return count; - } - - if(dg > -params->lbfgs.wolfe*dginit) { - width = dec; - } else { - // strong Wolfe condition (GGML_RWKV_LINESEARCH_BACKTRACKING_STRONG_WOLFE) - return count; - } - return count; - } - } - - if (*step < params->lbfgs.min_step) { - return GGML_RWKV_LINESEARCH_MINIMUM_STEP; - } - if (*step > params->lbfgs.max_step) { - return GGML_RWKV_LINESEARCH_MAXIMUM_STEP; - } - if (params->lbfgs.max_linesearch <= count) { - return GGML_RWKV_LINESEARCH_MAXIMUM_ITERATIONS; - } - - (*step) *= width; - } - - return GGML_RWKV_LINESEARCH_FAIL; -} - -static enum ggml_rwkv_opt_result ggml_rwkv_opt_lbfgs( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_opt_params params, - struct ggml_rwkv_tensor * f, - struct ggml_rwkv_cgraph * gf, - struct ggml_rwkv_cgraph * gb) { - if (params.lbfgs.linesearch == GGML_RWKV_LINESEARCH_BACKTRACKING_WOLFE || - params.lbfgs.linesearch == GGML_RWKV_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { - if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { - return GGML_RWKV_OPT_INVALID_WOLFE; - } - } - - gf->n_threads = params.n_threads; - gb->n_threads = params.n_threads; - - const int m = params.lbfgs.m; - - // these will store the parameters we want to optimize - struct ggml_rwkv_tensor * ps[GGML_RWKV_MAX_PARAMS]; - 
- int np = 0; - int nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->is_param) { - GGML_RWKV_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - GGML_RWKV_ASSERT(np < GGML_RWKV_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_rwkv_nelements(gf->nodes[i]); - } - } - - float * x = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // current parameters - float * xp = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // previous parameters - float * g = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // current gradient - float * gp = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // previous gradient - float * d = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; // search direction - - float * pf = params.past > 0 ? ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, params.past)->data : NULL; // past function values - - float fx = 0.0f; // cost function value - float xnorm = 0.0f; // ||x|| - float gnorm = 0.0f; // ||g|| - float step = 0.0f; - - // initialize x from the graph nodes - ggml_rwkv_opt_get_params(np, ps, x); - - // the L-BFGS memory - struct ggml_rwkv_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_rwkv_lbfgs_iteration_data)*m); - - for (int i = 0; i < m; ++i) { - lm[i].alpha = 0.0f; - lm[i].ys = 0.0f; - lm[i].s = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; - lm[i].y = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, nx)->data; - } - - // evaluate the function value and its gradient - { - ggml_rwkv_opt_set_params(np, ps, x); - - ggml_rwkv_graph_reset (gf); - ggml_rwkv_set_f32 (f->grad, 1.0f); - ggml_rwkv_graph_compute(ctx, gb); - - ggml_rwkv_opt_get_grad(np, ps, g); - - fx = ggml_rwkv_get_f32_1d(f, 0); - } - - if (pf) { - pf[0] = fx; - } - - float fx_best = fx; - - // search direction = -gradient - ggml_rwkv_vec_neg_f32(nx, d, g); - - // ||x||, ||g|| - ggml_rwkv_vec_norm_f32(nx, &xnorm, x); - ggml_rwkv_vec_norm_f32(nx, &gnorm, g); - - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - - // already optimized - if (gnorm/xnorm <= params.lbfgs.eps) { - return GGML_RWKV_OPT_OK; - } - - // initial step - ggml_rwkv_vec_norm_inv_f32(nx, &step, d); - - int j = 0; - int k = 1; - int ls = 0; - int end = 0; - int bound = 0; - int n_no_improvement = 0; - - float ys = 0.0f; - float yy = 0.0f; - float beta = 0.0f; - - while (true) { - // store the current position and gradient vectors - ggml_rwkv_vec_cpy_f32(nx, xp, x); - ggml_rwkv_vec_cpy_f32(nx, gp, g); - - ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps); - - if (ls < 0) { - // linesearch failed - go back to the previous point and return - ggml_rwkv_vec_cpy_f32(nx, x, xp); - ggml_rwkv_vec_cpy_f32(nx, g, gp); - - return ls; - } - - ggml_rwkv_vec_norm_f32(nx, &xnorm, x); - ggml_rwkv_vec_norm_f32(nx, &gnorm, g); - - GGML_RWKV_PRINT_DEBUG("f = %10.6f\n", ggml_rwkv_get_f32_1d(f, 0)); - - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - if (gnorm/xnorm <= params.lbfgs.eps) { - // converged - return GGML_RWKV_OPT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= k) { - const float rate = (pf[k%params.past] - fx)/fx; - - if (fabsf(rate) < params.delta) { - return GGML_RWKV_OPT_OK; - } - } - - pf[k%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx < fx_best) { - fx_best = fx; - n_no_improvement = 0; - } else { - n_no_improvement++; - - if 
(n_no_improvement >= params.max_no_improvement) { - return GGML_RWKV_OPT_OK; - } - } - } - - if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) { - // reached the maximum number of iterations - return GGML_RWKV_OPT_DID_NOT_CONVERGE; - } - - // update vectors s and y: - // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. - // y_{k+1} = g_{k+1} - g_{k}. - // - ggml_rwkv_vec_sub_f32(nx, lm[end].s, x, xp); - ggml_rwkv_vec_sub_f32(nx, lm[end].y, g, gp); - - // compute scalars ys and yy: - // ys = y^t \cdot s -> 1 / \rho. - // yy = y^t \cdot y. - // - ggml_rwkv_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s); - ggml_rwkv_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y); - - lm[end].ys = ys; - - // find new search direction - // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS - - bound = (m <= k) ? m : k; - k++; - end = (end + 1)%m; - - // initialize search direction with -g - ggml_rwkv_vec_neg_f32(nx, d, g); - - j = end; - for (int i = 0; i < bound; ++i) { - j = (j + m - 1) % m; - // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} - ggml_rwkv_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d); - lm[j].alpha /= lm[j].ys; - // q_{i} = q_{i+1} - \alpha_{i} y_{i} - ggml_rwkv_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha); - } - - ggml_rwkv_vec_scale_f32(nx, d, ys/yy); - - for (int i = 0; i < bound; ++i) { - // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} - ggml_rwkv_vec_dot_f32(nx, &beta, lm[j].y, d); - beta /= lm[j].ys; - // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} - ggml_rwkv_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta); - j = (j + 1)%m; - } - - step = 1.0; - } - - return GGML_RWKV_OPT_DID_NOT_CONVERGE; -} - -struct ggml_rwkv_opt_params ggml_rwkv_opt_default_params(enum ggml_rwkv_opt_type type) { - struct ggml_rwkv_opt_params result; - - switch (type) { - case GGML_RWKV_OPT_ADAM: - { - result = (struct ggml_rwkv_opt_params) { - .type = GGML_RWKV_OPT_ADAM, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 100, - - .print_forward_graph = true, - .print_backward_graph = true, - - .adam = { - .n_iter = 10000, - .alpha = 0.001f, - .beta1 = 0.9f, - .beta2 = 0.999f, - .eps = 1e-8f, - .eps_f = 1e-5f, - .eps_g = 1e-3f, - }, - }; - } break; - case GGML_RWKV_OPT_LBFGS: - { - result = (struct ggml_rwkv_opt_params) { - .type = GGML_RWKV_OPT_LBFGS, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 0, - - .print_forward_graph = true, - .print_backward_graph = true, - - .lbfgs = { - .m = 6, - .n_iter = 100, - .max_linesearch = 20, - - .eps = 1e-5f, - .ftol = 1e-4f, - .wolfe = 0.9f, - .min_step = 1e-20f, - .max_step = 1e+20f, - - .linesearch = GGML_RWKV_LINESEARCH_DEFAULT, - }, - }; - } break; - } - - return result; -} - -enum ggml_rwkv_opt_result ggml_rwkv_opt( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_opt_params params, - struct ggml_rwkv_tensor * f) { - bool free_ctx = false; - if (ctx == NULL) { - struct ggml_rwkv_init_params params_ctx = { - .mem_size = 16*1024*1024, - .mem_buffer = NULL, - }; - - ctx = ggml_rwkv_init(params_ctx); - if (ctx == NULL) { - return GGML_RWKV_OPT_NO_CONTEXT; - } - - free_ctx = true; - } - - enum ggml_rwkv_opt_result result = GGML_RWKV_OPT_OK; - - // build forward + backward compute graphs - struct ggml_rwkv_cgraph gf = ggml_rwkv_build_forward (f); - struct ggml_rwkv_cgraph gb = ggml_rwkv_build_backward(ctx, &gf, false); - - switch (params.type) { - case GGML_RWKV_OPT_ADAM: - { - result = ggml_rwkv_opt_adam(ctx, params, f, &gf, &gb); - } break; - case GGML_RWKV_OPT_LBFGS: - { - result = ggml_rwkv_opt_lbfgs(ctx, 
params, f, &gf, &gb); - } break; - } - - if (params.print_forward_graph) { - ggml_rwkv_graph_print (&gf); - ggml_rwkv_graph_dump_dot(&gf, NULL, "opt-forward.dot"); - } - - if (params.print_backward_graph) { - ggml_rwkv_graph_print (&gb); - ggml_rwkv_graph_dump_dot(&gb, &gf, "opt-backward.dot"); - } - - if (free_ctx) { - ggml_rwkv_free(ctx); - } - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -size_t ggml_rwkv_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK == 0); - const int nb = k / QK; - - for (int j = 0; j < n; j += k) { - block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK; - - quantize_row_q4_0_reference(src + j, y, k); - - for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; - const uint8_t vi1 = y[i].qs[l/2] >> 4; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n/QK*sizeof(block_q4_0)); -} - -size_t ggml_rwkv_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK == 0); - const int nb = k / QK; - - for (int j = 0; j < n; j += k) { - block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK; - - quantize_row_q4_1_reference(src + j, y, k); - - for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; - const uint8_t vi1 = y[i].qs[l/2] >> 4; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n/QK*sizeof(block_q4_1)); -} - -size_t ggml_rwkv_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK == 0); - const int nb = k / QK; - - for (int j = 0; j < n; j += k) { - block_q4_1_o * restrict y = (block_q4_1_o *) dst + j / QK; - - quantize_row_q4_1_o_reference(src + j, y, k); - - for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK; l += 2) { - const uint8_t vi0 = y[i].qs[l / 2] & 0xF; - const uint8_t vi1 = y[i].qs[l / 2] >> 4; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n / QK * sizeof(block_q4_1_o)); -} - -//////////////////////////////////////////////////////////////////////////////// - -int ggml_rwkv_cpu_has_avx(void) { -#if defined(__AVX__) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_avx2(void) { -#if defined(__AVX2__) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_avx512(void) { -#if defined(__AVX512F__) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_fma(void) { -#if defined(__FMA__) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_neon(void) { -#if defined(__ARM_NEON) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_arm_fma(void) { -#if defined(__ARM_FEATURE_FMA) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_f16c(void) { -#if defined(__F16C__) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_fp16_va(void) { -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_wasm_simd(void) { -#if defined(__wasm_simd128__) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_blas(void) { -#if defined(GGML_RWKV_USE_ACCELERATE) || defined(GGML_RWKV_USE_OPENBLAS) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_sse3(void) { -#if defined(__SSE3__) - return 1; -#else - return 0; -#endif -} - -int ggml_rwkv_cpu_has_vsx(void) { -#if defined(__POWER9_VECTOR__) - return 1; -#else - return 0; -#endif -} - 
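For reference, the removed ggml_rwkv.c exposed the ggml_rwkv_cpu_has_* feature probes above as plain int-returning functions. A minimal usage sketch, assuming the declarations from the (likewise removed) ggml_rwkv.h header; the helper name print_rwkv_system_info is hypothetical:

#include <stdio.h>
#include "ggml_rwkv.h"   // removed by this patch; declares the ggml_rwkv_cpu_has_* prototypes

// Report the compile-time SIMD/BLAS capabilities detected by the removed backend.
static void print_rwkv_system_info(void) {
    printf("AVX = %d | AVX2 = %d | AVX512 = %d | FMA = %d | NEON = %d | F16C = %d | BLAS = %d | SSE3 = %d\n",
           ggml_rwkv_cpu_has_avx(),    ggml_rwkv_cpu_has_avx2(),
           ggml_rwkv_cpu_has_avx512(), ggml_rwkv_cpu_has_fma(),
           ggml_rwkv_cpu_has_neon(),   ggml_rwkv_cpu_has_f16c(),
           ggml_rwkv_cpu_has_blas(),   ggml_rwkv_cpu_has_sse3());
}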
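The quantization test that follows exercises the Q4_0 packing produced by ggml_rwkv_quantize_q4_0: each 20-byte block stores one float scale d followed by QK/2 = 16 bytes of packed 4-bit values, where each stored value is round(x / d) + 8 and even-indexed elements occupy the low nibble. A minimal decoding sketch under those assumptions (QK == 32; the helper name dequantize_q4_0_sketch is hypothetical, not part of the original API):

#include <stdint.h>
#include <string.h>

// Decode one 20-byte Q4_0 block (float scale + 16 packed nibbles) into 32 floats,
// mirroring the layout checked by ggml_rwkv_test_quantization below.
static void dequantize_q4_0_sketch(const uint8_t * block, float * out) {
    float d;
    memcpy(&d, block, sizeof(float));            // per-block scale, stored first
    const uint8_t * qs = block + sizeof(float);  // 16 bytes of packed nibbles
    for (int i = 0; i < 32; i += 2) {
        const uint8_t byte = qs[i / 2];
        out[i + 0] = d * (float) ((int) (byte & 0x0F) - 8); // even index: low nibble
        out[i + 1] = d * (float) ((int) (byte >> 4)  - 8);  // odd index: high nibble
    }
}

The +8 offset maps the unsigned nibble range 0..15 onto signed steps around zero, so a single per-block scale can reconstruct both positive and negative weights.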
-//////////////////////////////////////////////////////////////////////////////// - -#define GGML_RWKV_TEST_SET_ELEMENT_F32(tensor, i, value) *(float *) ((char *) tensor->data + 4 * i) = value - -#define GGML_RWKV_TEST_ASSERT(x, ...) do {\ - if (!(x)) {\ - fprintf(stderr, "*** Assertion failed ***\n");\ - fprintf(stderr, __VA_ARGS__);\ - fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ - abort();\ - }\ - } while (0) - -#define GGML_RWKV_TEST_ASSERT_ELEMENT_F32(tensor, i, expected_value) do {\ - float actual = *(float *) ((char *) tensor->data + 4 * i);\ - GGML_RWKV_TEST_ASSERT(fabsf(actual - expected_value) <= 0.0001F, "At %s[%d]: expected %f, actual %f", #tensor, i, expected_value, actual);\ - } while (0) - -// Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/tests/test-quantize.c -void ggml_rwkv_test_quantization(void) { - float src[QK]; - uint8_t dst[24]; - int64_t hist[16]; - - for (int i = 0; i < QK; i++) { - src[i] = (float) (i + 1); - } - - size_t size = ggml_rwkv_quantize_q4_0(src, dst, QK, QK, hist); - GGML_RWKV_TEST_ASSERT(size == 20, "%zd", size); - float max_result = ((float *) dst)[0]; - float max_expected = src[31] / ((1 << 3) - 1); - GGML_RWKV_TEST_ASSERT(max_result == max_expected, "%f, %f", max_result, max_expected); - for (int i = 0; i < QK; i++) { - uint8_t q4_result = (i % 2) ? (dst[sizeof(float) + i / 2] >> 4) : (dst[sizeof(float) + i / 2] & 0xF); - uint8_t q4_expected = roundf(src[i] / max_expected) + 8; - GGML_RWKV_TEST_ASSERT(q4_result == q4_expected, "%d, %d", q4_result, q4_expected); - } - - size = ggml_rwkv_quantize_q4_1(src, dst, QK, QK, hist); - GGML_RWKV_TEST_ASSERT(size == 24, "%zd", size); - float delta_result = ((float *) dst)[0]; - float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1); - GGML_RWKV_TEST_ASSERT(delta_result == delta_expected, "%f, %f", delta_result, delta_expected); - float min_result = ((float *) dst)[1]; - float min_expected = src[0]; - GGML_RWKV_TEST_ASSERT(min_result == min_expected, "%f, %f", min_result, min_expected); - for (int i = 0; i < QK; i++) { - uint8_t q4_result = (i % 2) ? 
(dst[sizeof(float) * 2 + i / 2] >> 4) : (dst[sizeof(float) * 2 + i / 2] & 0xF); - uint8_t q4_expected = roundf((src[i] - min_expected) / delta_expected); - GGML_RWKV_TEST_ASSERT(q4_result == q4_expected, "%d, %d", q4_result, q4_expected); - } -} - -void ggml_rwkv_test_quantization_q4_1_o(void) { - float src[QK]; - uint8_t dst[24]; - int64_t hist[16]; - - for (int i = 0; i < QK; i++) { - src[i] = (float) (i + 1); - } - - size_t size = ggml_rwkv_quantize_q4_1_o(src, dst, QK, QK, hist); - GGML_RWKV_TEST_ASSERT(size == 24, "%zd", size); - - float delta_result = ggml_rwkv_half_to_float_reference(((block_q4_1_o *) dst)->d); - float delta_expected = (src[30] - src[0]) / ((1 << 4) - 1); - GGML_RWKV_TEST_ASSERT(delta_result == delta_expected, "%f, %f", delta_result, delta_expected); - - float min_result = ggml_rwkv_half_to_float_reference(((block_q4_1_o *) dst)->m); - float min_expected = src[0]; - GGML_RWKV_TEST_ASSERT(min_result == min_expected, "%f, %f", min_result, min_expected); - - uint16_t outlier_index = ((block_q4_1_o *) dst)->outlier_index; - uint16_t outlier_index_expected = 31; - GGML_RWKV_TEST_ASSERT(outlier_index == outlier_index_expected, "%d, %d", outlier_index, outlier_index_expected); - - float outlier_value_result = ggml_rwkv_half_to_float_reference(((block_q4_1_o *) dst)->outlier_value); - float outlier_value_expected = src[31]; - GGML_RWKV_TEST_ASSERT(outlier_value_result == outlier_value_expected, "%f, %f", outlier_value_result, outlier_value_expected); - - for (int i = 0; i < QK - 1; i++) { - uint8_t q4_result = (i % 2) ? (dst[sizeof(float) * 2 + i / 2] >> 4) : (dst[sizeof(float) * 2 + i / 2] & 0xF); - uint8_t q4_expected = roundf((src[i] - min_expected) / delta_expected); - GGML_RWKV_TEST_ASSERT(q4_result == q4_expected, "%d: %d, %d", i, q4_result, q4_expected); - } - - float dequantized[QK]; - dequantize_row_q4_1_o(dst, dequantized, QK); - - for (int i = 0; i < QK; i++) { - float actual = dequantized[i]; - float expected = src[i]; - float diff = fabsf(actual - expected); - // Difference looks huge, but the range is 0..31 -- compared to range, it is not that huge - GGML_RWKV_TEST_ASSERT(diff <= 1.0F, "%d: %f, %f", i, actual, expected); - } -} - -void ggml_rwkv_run_test_suite(void) { - ggml_rwkv_test_quantization(); - ggml_rwkv_test_quantization_q4_1_o(); - - struct ggml_rwkv_init_params params; - params.mem_size = 16 * 1024; - params.mem_buffer = NULL; - struct ggml_rwkv_context * ctx = ggml_rwkv_init(params); - - struct ggml_rwkv_tensor * a = ggml_rwkv_new_tensor_2d(ctx, GGML_RWKV_TYPE_F32, 3, 2); - GGML_RWKV_TEST_SET_ELEMENT_F32(a, 0, 1.0051F); - GGML_RWKV_TEST_SET_ELEMENT_F32(a, 1, 1.0484F); - GGML_RWKV_TEST_SET_ELEMENT_F32(a, 2, -0.4361F); - GGML_RWKV_TEST_SET_ELEMENT_F32(a, 3, -0.6984F); - GGML_RWKV_TEST_SET_ELEMENT_F32(a, 4, 1.7310F); - GGML_RWKV_TEST_SET_ELEMENT_F32(a, 5, -0.0446F); - - struct ggml_rwkv_tensor * b = ggml_rwkv_new_tensor_2d(ctx, GGML_RWKV_TYPE_F32, 3, 2); - GGML_RWKV_TEST_SET_ELEMENT_F32(b, 0, -0.2566F); - GGML_RWKV_TEST_SET_ELEMENT_F32(b, 1, -0.1412F); - GGML_RWKV_TEST_SET_ELEMENT_F32(b, 2, 1.6200F); - GGML_RWKV_TEST_SET_ELEMENT_F32(b, 3, 0.5156F); - GGML_RWKV_TEST_SET_ELEMENT_F32(b, 4, -0.3934F); - GGML_RWKV_TEST_SET_ELEMENT_F32(b, 5, -0.0694F); - - // Test against torch.exp(a) - struct ggml_rwkv_tensor * exp_a = ggml_rwkv_exp(ctx, a); - - struct ggml_rwkv_cgraph graph = ggml_rwkv_build_forward(exp_a); - graph.n_threads = 2; - ggml_rwkv_graph_compute(ctx, &graph); - - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(exp_a, 0, 2.7322F); - 
GGML_RWKV_TEST_ASSERT_ELEMENT_F32(exp_a, 1, 2.8531F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(exp_a, 2, 0.6466F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(exp_a, 3, 0.4974F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(exp_a, 4, 5.6463F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(exp_a, 5, 0.9564F); - - // Test against (1 - a) in PyTorch - struct ggml_rwkv_tensor * one_minus_a = ggml_rwkv_1_minus_x(ctx, a); - - graph = ggml_rwkv_build_forward(one_minus_a); - graph.n_threads = 2; - ggml_rwkv_graph_compute(ctx, &graph); - - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(one_minus_a, 0, -0.0051F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(one_minus_a, 1, -0.0484F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(one_minus_a, 2, 1.4361F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(one_minus_a, 3, 1.6984F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(one_minus_a, 4, -0.7310F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(one_minus_a, 5, 1.0446F); - - // Test against torch.sigmoid(a) - struct ggml_rwkv_tensor * sigmoid_a = ggml_rwkv_sigmoid(ctx, a); - - graph = ggml_rwkv_build_forward(sigmoid_a); - graph.n_threads = 2; - ggml_rwkv_graph_compute(ctx, &graph); - - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 0, 0.7321F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 1, 0.7405F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 2, 0.3927F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 3, 0.3322F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 4, 0.8495F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 5, 0.4889F); - - // Test against torch.maximum(a, b) - struct ggml_rwkv_tensor * max_a_b = ggml_rwkv_max(ctx, a, b); - - graph = ggml_rwkv_build_forward(max_a_b); - graph.n_threads = 2; - ggml_rwkv_graph_compute(ctx, &graph); - - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(max_a_b, 0, 1.0051F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(max_a_b, 1, 1.0484F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(max_a_b, 2, 1.6200F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(max_a_b, 3, 0.5156F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(max_a_b, 4, 1.7310F); - GGML_RWKV_TEST_ASSERT_ELEMENT_F32(max_a_b, 5, -0.0446F); - - ggml_rwkv_free(ctx); -} diff --git a/otherarch/ggml_rwkv.h b/otherarch/ggml_rwkv.h deleted file mode 100644 index 97e8028cb..000000000 --- a/otherarch/ggml_rwkv.h +++ /dev/null @@ -1,645 +0,0 @@ -#pragma once - - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include - -#define GGML_RWKV_MAX_DIMS 4 -#define GGML_RWKV_MAX_NODES 4096 -#define GGML_RWKV_MAX_PARAMS 16 -#define GGML_RWKV_MAX_CONTEXTS 64 -#define GGML_RWKV_MAX_OPT 4 - -#ifdef __ARM_NEON -// we use the built-in 16-bit float type -typedef __fp16 ggml_rwkv_fp16_t; -#else -typedef uint16_t ggml_rwkv_fp16_t; -#endif - -// convert FP16 <-> FP32 -float ggml_rwkv_fp16_to_fp32(ggml_rwkv_fp16_t x); -ggml_rwkv_fp16_t ggml_rwkv_fp32_to_fp16(float x); - -struct ggml_rwkv_object; -struct ggml_rwkv_context; - -enum ggml_rwkv_type { - GGML_RWKV_TYPE_Q4_0, - // Stores min and delta per block, does quantized matmul. - GGML_RWKV_TYPE_Q4_1, - // Same as Q4_1, but stores outliers separately, and matmul is done in FP32. - // An outlier is the single absmax element in the quantized block. 
- GGML_RWKV_TYPE_Q4_1_O, - GGML_RWKV_TYPE_I8, - GGML_RWKV_TYPE_I16, - GGML_RWKV_TYPE_I32, - GGML_RWKV_TYPE_F16, - GGML_RWKV_TYPE_F32, - GGML_RWKV_TYPE_COUNT, -}; - -// available tensor operations: -enum ggml_rwkv_op { - GGML_RWKV_OP_NONE = 0, - - GGML_RWKV_OP_DUP, - GGML_RWKV_OP_ADD, - GGML_RWKV_OP_SUB, - GGML_RWKV_OP_MUL, - GGML_RWKV_OP_DIV, - GGML_RWKV_OP_SQR, - GGML_RWKV_OP_SQRT, - GGML_RWKV_OP_SUM, - GGML_RWKV_OP_MEAN, - GGML_RWKV_OP_REPEAT, - GGML_RWKV_OP_ABS, - GGML_RWKV_OP_SGN, - GGML_RWKV_OP_NEG, - // Element-wise exponential function `e^x`. - // Same as `torch.exp(x)` from PyTorch. - GGML_RWKV_OP_EXP, - // Element-wise `1 - x`. - GGML_RWKV_OP_1_MINUS_X, - - // Element-wise maximum of 2 values. Argument shapes must match. - // Same as `torch.maximum(x)` from PyTorch. - GGML_RWKV_OP_MAX, - - GGML_RWKV_OP_STEP, - GGML_RWKV_OP_RELU, - GGML_RWKV_OP_GELU, - // Element-wise sigmoid activation `1 / (1 + e^-x)`, also called logistic function. - // Same as `torch.sigmoid(x)` from PyTorch. - GGML_RWKV_OP_SIGMOID, - GGML_RWKV_OP_SILU, - GGML_RWKV_OP_NORM, // normalize - GGML_RWKV_OP_RMS_NORM, - - GGML_RWKV_OP_MUL_MAT, - - GGML_RWKV_OP_SCALE, - GGML_RWKV_OP_CPY, - GGML_RWKV_OP_RESHAPE, - GGML_RWKV_OP_VIEW, - GGML_RWKV_OP_PERMUTE, - GGML_RWKV_OP_TRANSPOSE, - GGML_RWKV_OP_GET_ROWS, - GGML_RWKV_OP_DIAG_MASK_INF, - GGML_RWKV_OP_SOFT_MAX, - GGML_RWKV_OP_ROPE, - GGML_RWKV_OP_CONV_1D_1S, - GGML_RWKV_OP_CONV_1D_2S, - - GGML_RWKV_OP_FLASH_ATTN, - GGML_RWKV_OP_FLASH_FF, - - GGML_RWKV_OP_COUNT, -}; - -// n-dimensional tensor -struct ggml_rwkv_tensor { - enum ggml_rwkv_type type; - - int n_dims; - int ne[GGML_RWKV_MAX_DIMS]; // number of elements - size_t nb[GGML_RWKV_MAX_DIMS]; // stride in bytes: - // nb[0] = sizeof(type) - // nb[1] = nb[0] * ne[0] + padding - // nb[i] = nb[i-1] * ne[i-1] - - // compute data - enum ggml_rwkv_op op; - - bool is_param; - - struct ggml_rwkv_tensor * grad; - struct ggml_rwkv_tensor * src0; - struct ggml_rwkv_tensor * src1; - struct ggml_rwkv_tensor * opt[GGML_RWKV_MAX_OPT]; - - // thread scheduling - int n_tasks; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; - - void * data; - char padding[8]; -}; - -// computation graph -struct ggml_rwkv_cgraph { - int n_nodes; - int n_leafs; - int n_threads; - - size_t work_size; - struct ggml_rwkv_tensor * work; - - struct ggml_rwkv_tensor * nodes[GGML_RWKV_MAX_NODES]; - struct ggml_rwkv_tensor * grads[GGML_RWKV_MAX_NODES]; - struct ggml_rwkv_tensor * leafs[GGML_RWKV_MAX_NODES]; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; -}; - -// scratch buffer -struct ggml_rwkv_scratch { - size_t offs; - size_t size; - void * data; -}; - -struct ggml_rwkv_init_params { - // memory pool - size_t mem_size; // bytes - void * mem_buffer; // if NULL, memory will be allocated internally -}; - -void ggml_rwkv_time_init(void); // call this once at the beginning of the program -int64_t ggml_rwkv_time_ms(void); -int64_t ggml_rwkv_time_us(void); -int64_t ggml_rwkv_cycles(void); -int64_t ggml_rwkv_cycles_per_ms(void); - -void ggml_rwkv_print_object (const struct ggml_rwkv_object * obj); -void ggml_rwkv_print_objects(const struct ggml_rwkv_context * ctx); - -int ggml_rwkv_nelements(const struct ggml_rwkv_tensor * tensor); -size_t ggml_rwkv_nbytes (const struct ggml_rwkv_tensor * tensor); - -int ggml_rwkv_blck_size (enum ggml_rwkv_type type); -size_t ggml_rwkv_type_size (enum ggml_rwkv_type type); // size in bytes for all elements in a block -float ggml_rwkv_type_sizef(enum ggml_rwkv_type 
type); // ggml_rwkv_type_size()/ggml_rwkv_blck_size() as float - -size_t ggml_rwkv_element_size(const struct ggml_rwkv_tensor * tensor); - -struct ggml_rwkv_context * ggml_rwkv_init(struct ggml_rwkv_init_params params); -void ggml_rwkv_free(struct ggml_rwkv_context * ctx); - -size_t ggml_rwkv_used_mem(const struct ggml_rwkv_context * ctx); - -size_t ggml_rwkv_set_scratch(struct ggml_rwkv_context * ctx, struct ggml_rwkv_scratch scratch); - -bool ggml_rwkv_mlock_supported(void); -bool ggml_rwkv_mlock(struct ggml_rwkv_context * ctx, char ** err_p); - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int n_dims, - const int *ne); - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_1d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0); - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_2d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0, - int ne1); - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_3d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0, - int ne1, - int ne2); - -struct ggml_rwkv_tensor * ggml_rwkv_new_tensor_4d( - struct ggml_rwkv_context * ctx, - enum ggml_rwkv_type type, - int ne0, - int ne1, - int ne2, - int ne3); - -struct ggml_rwkv_tensor * ggml_rwkv_new_i32(struct ggml_rwkv_context * ctx, int32_t value); -struct ggml_rwkv_tensor * ggml_rwkv_new_f32(struct ggml_rwkv_context * ctx, float value); - -struct ggml_rwkv_tensor * ggml_rwkv_dup_tensor (struct ggml_rwkv_context * ctx, const struct ggml_rwkv_tensor * src); -struct ggml_rwkv_tensor * ggml_rwkv_view_tensor(struct ggml_rwkv_context * ctx, const struct ggml_rwkv_tensor * src); - -struct ggml_rwkv_tensor * ggml_rwkv_set_zero(struct ggml_rwkv_tensor * tensor); -struct ggml_rwkv_tensor * ggml_rwkv_set_i32 (struct ggml_rwkv_tensor * tensor, int32_t value); -struct ggml_rwkv_tensor * ggml_rwkv_set_f32 (struct ggml_rwkv_tensor * tensor, float value); - -int32_t ggml_rwkv_get_i32_1d(const struct ggml_rwkv_tensor * tensor, int i); -void ggml_rwkv_set_i32_1d(const struct ggml_rwkv_tensor * tensor, int i, int32_t value); - -float ggml_rwkv_get_f32_1d(const struct ggml_rwkv_tensor * tensor, int i); -void ggml_rwkv_set_f32_1d(const struct ggml_rwkv_tensor * tensor, int i, float value); - - void * ggml_rwkv_get_data (const struct ggml_rwkv_tensor * tensor); -float * ggml_rwkv_get_data_f32(const struct ggml_rwkv_tensor * tensor); - -// -// operations on tensors with backpropagation -// - -struct ggml_rwkv_tensor * ggml_rwkv_dup( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_add( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_sub( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_mul( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_div( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_sqr( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_sqrt( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -// return scalar -// TODO: compute sum along rows -struct ggml_rwkv_tensor * ggml_rwkv_sum( - struct ggml_rwkv_context * ctx, - struct 
ggml_rwkv_tensor * a); - -// mean along rows -struct ggml_rwkv_tensor * ggml_rwkv_mean( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -// if a is the same shape as b, and a is not parameter, return a -// otherwise, return a new tensor: repeat(a) to fit in b -struct ggml_rwkv_tensor * ggml_rwkv_repeat( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_abs( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_sgn( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_neg( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_exp( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_1_minus_x( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_max( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_step( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_relu( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -// TODO: double-check this computation is correct -struct ggml_rwkv_tensor * ggml_rwkv_gelu( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_sigmoid( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_silu( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -// normalize along rows -// TODO: eps is hardcoded to 1e-5 for now -struct ggml_rwkv_tensor * ggml_rwkv_norm( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_rms_norm( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -// A: m rows, n columns -// B: p rows, n columns (i.e. 
we transpose it internally) -// result is m columns, p rows -struct ggml_rwkv_tensor * ggml_rwkv_mul_mat( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -// -// operations on tensors without backpropagation -// - -// in-place, returns view(a) -struct ggml_rwkv_tensor * ggml_rwkv_scale( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -// a -> b, return view(b) -struct ggml_rwkv_tensor * ggml_rwkv_cpy( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -// return view(a), b specifies the new shape -// TODO: when we start computing gradient, make a copy instead of view -struct ggml_rwkv_tensor * ggml_rwkv_reshape( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -// return view(a) -// TODO: when we start computing gradient, make a copy instead of view -struct ggml_rwkv_tensor * ggml_rwkv_reshape_2d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - int ne1); - -// return view(a) -// TODO: when we start computing gradient, make a copy instead of view -struct ggml_rwkv_tensor * ggml_rwkv_reshape_3d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - int ne1, - int ne2); - -// offset in bytes -struct ggml_rwkv_tensor * ggml_rwkv_view_1d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - size_t offset); - -struct ggml_rwkv_tensor * ggml_rwkv_view_2d( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int ne0, - int ne1, - size_t nb1, // row stride in bytes - size_t offset); - -struct ggml_rwkv_tensor * ggml_rwkv_permute( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int axis0, - int axis1, - int axis2, - int axis3); - -// alias for ggml_rwkv_permute(ctx, a, 1, 0, 2, 3) -struct ggml_rwkv_tensor * ggml_rwkv_transpose( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -struct ggml_rwkv_tensor * ggml_rwkv_get_rows( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -// set elements above the diagonal to -INF -// in-place, returns view(a) -struct ggml_rwkv_tensor * ggml_rwkv_diag_mask_inf( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int n_past); - -// in-place, returns view(a) -struct ggml_rwkv_tensor * ggml_rwkv_soft_max( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a); - -// rotary position embedding -// in-place, returns view(a) -// if mode == 1, skip n_past elements -// TODO: avoid creating a new tensor every time -struct ggml_rwkv_tensor * ggml_rwkv_rope( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - int n_past, - int n_dims, - int mode); - -// padding = 1 -// TODO: we don't support extra parameters for now -// that's why we are hard-coding the stride, padding, and dilation -// not great .. 
-struct ggml_rwkv_tensor * ggml_rwkv_conv_1d_1s( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_conv_1d_2s( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b); - -struct ggml_rwkv_tensor * ggml_rwkv_flash_attn( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * q, - struct ggml_rwkv_tensor * k, - struct ggml_rwkv_tensor * v, - bool masked); - -struct ggml_rwkv_tensor * ggml_rwkv_flash_ff( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * a, - struct ggml_rwkv_tensor * b0, - struct ggml_rwkv_tensor * b1, - struct ggml_rwkv_tensor * c0, - struct ggml_rwkv_tensor * c1); - -// -// automatic differentiation -// - -void ggml_rwkv_set_param( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_tensor * tensor); - -void ggml_rwkv_build_forward_expand(struct ggml_rwkv_cgraph * cgraph, struct ggml_rwkv_tensor * tensor); - -struct ggml_rwkv_cgraph ggml_rwkv_build_forward (struct ggml_rwkv_tensor * tensor); -struct ggml_rwkv_cgraph ggml_rwkv_build_backward(struct ggml_rwkv_context * ctx, struct ggml_rwkv_cgraph * gf, bool keep); - -void ggml_rwkv_graph_compute(struct ggml_rwkv_context * ctx, struct ggml_rwkv_cgraph * cgraph); -void ggml_rwkv_graph_reset (struct ggml_rwkv_cgraph * cgraph); - -// print info and performance information for the graph -void ggml_rwkv_graph_print(const struct ggml_rwkv_cgraph * cgraph); - -// dump the graph into a file using the dot format -void ggml_rwkv_graph_dump_dot(const struct ggml_rwkv_cgraph * gb, const struct ggml_rwkv_cgraph * gf, const char * filename); - -// -// optimization -// - -// optimization methods -enum ggml_rwkv_opt_type { - GGML_RWKV_OPT_ADAM, - GGML_RWKV_OPT_LBFGS, -}; - -// linesearch methods -enum ggml_rwkv_linesearch { - GGML_RWKV_LINESEARCH_DEFAULT = 1, - - GGML_RWKV_LINESEARCH_BACKTRACKING_ARMIJO = 0, - GGML_RWKV_LINESEARCH_BACKTRACKING_WOLFE = 1, - GGML_RWKV_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, -}; - -// optimization return values -enum ggml_rwkv_opt_result { - GGML_RWKV_OPT_OK = 0, - GGML_RWKV_OPT_DID_NOT_CONVERGE, - GGML_RWKV_OPT_NO_CONTEXT, - GGML_RWKV_OPT_INVALID_WOLFE, - GGML_RWKV_OPT_FAIL, - - GGML_RWKV_LINESEARCH_FAIL = -128, - GGML_RWKV_LINESEARCH_MINIMUM_STEP, - GGML_RWKV_LINESEARCH_MAXIMUM_STEP, - GGML_RWKV_LINESEARCH_MAXIMUM_ITERATIONS, - GGML_RWKV_LINESEARCH_INVALID_PARAMETERS, -}; - -// optimization parameters -// -// see ggml.c (ggml_rwkv_opt_default_params) for default values -// -struct ggml_rwkv_opt_params { - enum ggml_rwkv_opt_type type; - - int n_threads; - - // delta-based convergence test - // - // if past == 0 - disabled - // if past > 0: - // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) - // - int past; - float delta; - - // maximum number of iterations without improvement - // - // if 0 - disabled - // if > 0: - // assume convergence if no cost improvement in this number of iterations - // - int max_no_improvement; - - bool print_forward_graph; - bool print_backward_graph; - - // ADAM parameters - struct { - int n_iter; - - float alpha; // learning rate - float beta1; - float beta2; - float eps; // epsilon for numerical stability - float eps_f; // epsilon for convergence test - float eps_g; // epsilon for convergence test - } adam; - - // LBFGS parameters - struct { - int m; // number of corrections to approximate the inv. 
Hessian - int n_iter; - int max_linesearch; - - float eps; // convergence tolerance - float ftol; // line search tolerance - float wolfe; - float min_step; - float max_step; - - enum ggml_rwkv_linesearch linesearch; - } lbfgs; -}; - -struct ggml_rwkv_opt_params ggml_rwkv_opt_default_params(enum ggml_rwkv_opt_type type); - -// optimize the function defined by the tensor f -enum ggml_rwkv_opt_result ggml_rwkv_opt( - struct ggml_rwkv_context * ctx, - struct ggml_rwkv_opt_params params, - struct ggml_rwkv_tensor * f); - -// -// quantization -// - -size_t ggml_rwkv_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); -size_t ggml_rwkv_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); -size_t ggml_rwkv_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist); - -// -// system info -// - -int ggml_rwkv_cpu_has_avx(void); -int ggml_rwkv_cpu_has_avx2(void); -int ggml_rwkv_cpu_has_avx512(void); -int ggml_rwkv_cpu_has_fma(void); -int ggml_rwkv_cpu_has_neon(void); -int ggml_rwkv_cpu_has_arm_fma(void); -int ggml_rwkv_cpu_has_f16c(void); -int ggml_rwkv_cpu_has_fp16_va(void); -int ggml_rwkv_cpu_has_wasm_simd(void); -int ggml_rwkv_cpu_has_blas(void); -int ggml_rwkv_cpu_has_sse3(void); -int ggml_rwkv_cpu_has_vsx(void); - -// Run test suite for ggml. -// Exits normally, if all tests pass. -// Aborts the execution if any test did not pass. -void ggml_rwkv_run_test_suite(); - -#ifdef __cplusplus -} -#endif diff --git a/otherarch/otherarch.h b/otherarch/otherarch.h index 3713be7ce..f194895e3 100644 --- a/otherarch/otherarch.h +++ b/otherarch/otherarch.h @@ -271,61 +271,4 @@ struct stablelm_model { std::map tensors; }; -struct rwkv_layer { - struct ggml_rwkv_tensor * ln1_weight; - struct ggml_rwkv_tensor * ln1_bias; - // RWKV, also called "attention" by the author. - struct ggml_rwkv_tensor * att_time_mix_k; - struct ggml_rwkv_tensor * att_time_mix_v; - struct ggml_rwkv_tensor * att_time_mix_r; - struct ggml_rwkv_tensor * att_time_first; - struct ggml_rwkv_tensor * att_time_decay; - struct ggml_rwkv_tensor * att_key; - struct ggml_rwkv_tensor * att_value; - struct ggml_rwkv_tensor * att_receptance; - struct ggml_rwkv_tensor * att_output; - - struct ggml_rwkv_tensor * ln2_weight; - struct ggml_rwkv_tensor * ln2_bias; - - // FFN. - struct ggml_rwkv_tensor * ffn_time_mix_k; - struct ggml_rwkv_tensor * ffn_time_mix_r; - struct ggml_rwkv_tensor * ffn_key; - struct ggml_rwkv_tensor * ffn_value; - struct ggml_rwkv_tensor * ffn_receptance; -}; - -struct rwkv_model { - int32_t n_vocab; - int32_t n_layer; - int32_t n_embed; - // 0 for float32, 1 for float16. 
- int32_t data_type; - - struct ggml_rwkv_tensor * emb; - - struct ggml_rwkv_tensor * ln0_weight; - struct ggml_rwkv_tensor * ln0_bias; - - std::vector layers; - - struct ggml_rwkv_tensor * ln_out_weight; - struct ggml_rwkv_tensor * ln_out_bias; - - struct ggml_rwkv_tensor * head; -}; -struct rwkv_context { - struct rwkv_model * model; - struct ggml_rwkv_tensor * token_index; - struct ggml_rwkv_tensor * state; - struct ggml_rwkv_tensor ** state_parts; - struct ggml_rwkv_tensor * logits; - struct ggml_rwkv_context * ctx; - struct ggml_rwkv_cgraph * graph; - bool freed; - float * state_in = 0; //stores input state, or use null for a new state - float * state_out = 0; //stores address of output state buffer - float * logits_out = 0; //stores address of output logit buffer -}; diff --git a/otherarch/rwkv.cpp b/otherarch/rwkv.cpp index cc4aa2b6e..702872020 100644 --- a/otherarch/rwkv.cpp +++ b/otherarch/rwkv.cpp @@ -1,10 +1,11 @@ //adapted from RWKV.cpp repo under MIT license // https://github.com/saharNooby/rwkv.cpp -#include "rwkv.h" -#include "ggml_rwkv.h" #include "otherarch.h" +#include "rwkv.h" +#include "ggml.h" + #include #include #include @@ -17,14 +18,10 @@ #include #include -#include "model_adapter.h" - #include "rwkv_vocab.cpp" // --- Utilities --- -#define FP32_SIZE 4 - // Checks that x is not false. If x is false, prints fancy message to stderr and returns 0. #define RWKV_ASSERT_FALSE(x, ...) \ do { \ @@ -51,22 +48,86 @@ bool read_int32(FILE * file, int32_t * dest) { return true; } -static const ggml_rwkv_type FORMAT_TYPE_TO_GGML_RWKV_TYPE[5] = { - GGML_RWKV_TYPE_F32, - GGML_RWKV_TYPE_F16, - GGML_RWKV_TYPE_Q4_0, - GGML_RWKV_TYPE_Q4_1, - GGML_RWKV_TYPE_Q4_1_O +#define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT + +#define FORMAT_TYPE_COUNT 10 + +static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[FORMAT_TYPE_COUNT] = { + GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_UNKNOWN, // Unused + GGML_TYPE_Q4_2, + GGML_TYPE_UNKNOWN, // Unused + GGML_TYPE_Q5_0, + GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0 }; +static int32_t format_name_to_format_type(const char * format_name) { + if (strcmp(format_name, "Q4_0") == 0) return 2; + if (strcmp(format_name, "Q4_1") == 0) return 3; + if (strcmp(format_name, "Q4_2") == 0) return 5; + if (strcmp(format_name, "Q5_0") == 0) return 7; + if (strcmp(format_name, "Q5_1") == 0) return 8; + if (strcmp(format_name, "Q8_0") == 0) return 9; + + return -1; +} + // --- Model definition and loading utilities --- +struct rwkv_layer { + struct ggml_tensor * ln1_weight; + struct ggml_tensor * ln1_bias; + // RWKV, also called "attention" by the author. + struct ggml_tensor * att_time_mix_k; + struct ggml_tensor * att_time_mix_v; + struct ggml_tensor * att_time_mix_r; + struct ggml_tensor * att_time_first; + struct ggml_tensor * att_time_decay; + struct ggml_tensor * att_key; + struct ggml_tensor * att_value; + struct ggml_tensor * att_receptance; + struct ggml_tensor * att_output; + + struct ggml_tensor * ln2_weight; + struct ggml_tensor * ln2_bias; + + // FFN. + struct ggml_tensor * ffn_time_mix_k; + struct ggml_tensor * ffn_time_mix_r; + struct ggml_tensor * ffn_key; + struct ggml_tensor * ffn_value; + struct ggml_tensor * ffn_receptance; +}; + +struct rwkv_model { + int32_t n_vocab; + int32_t n_layer; + int32_t n_embed; + // 0 for float32, 1 for float16. 
+ int32_t data_type; + + struct ggml_tensor * emb; + + struct ggml_tensor * ln0_weight; + struct ggml_tensor * ln0_bias; + + std::vector layers; + + struct ggml_tensor * ln_out_weight; + struct ggml_tensor * ln_out_bias; + + struct ggml_tensor * head; +}; // Finds model parameter by key and sets it into dest. // If the parameter was not found, returns false. -bool set_parameter(std::unordered_map * parameters, char * key, struct ggml_rwkv_tensor ** dest) { - struct ggml_rwkv_tensor * parameter = (*parameters)[key]; +bool set_parameter(std::unordered_map * parameters, char * key, struct ggml_tensor ** dest) { + struct ggml_tensor * parameter = (*parameters)[key]; RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key); *dest = parameter; return true; @@ -74,7 +135,7 @@ bool set_parameter(std::unordered_map * // Finds block parameter by block index and key and sets it into dest. // If the parameter was not found, returns false. -bool set_block_parameter(std::unordered_map * parameters, int32_t block_index, char * key, struct ggml_rwkv_tensor ** dest) { +bool set_block_parameter(std::unordered_map * parameters, int32_t block_index, char * key, struct ggml_tensor ** dest) { char full_key[128]; sprintf(full_key, "blocks.%d.%s", block_index, key); return set_parameter(parameters, full_key, dest); @@ -82,18 +143,70 @@ bool set_block_parameter(std::unordered_mapn_layer > 0, "Non-positive n_layer %d", model->n_layer); read_int32(file, &(model->data_type)); + RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); + RWKV_ASSERT_NULL( - model->data_type == 0 || - model->data_type == 1 || - model->data_type == 2 || - model->data_type == 3 || - model->data_type == 4, - "Unsupported model data type %d", - model->data_type + model->data_type != 4, + "Models in Q4_1_O format cannot be loaded anymore because the format was removed. You need to quantize the model into another format" + ); + + RWKV_ASSERT_NULL( + model->data_type != 6, + "Models in Q4_3 format cannot be loaded anymore because the format was removed. You need to quantize the model into another format" ); // Parameter tensors would take at least this amount in memory. 
@@ -152,12 +267,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr size_t(256) * 1024 * 1024; // Initialize ggml - struct ggml_rwkv_init_params params; + struct ggml_init_params params; params.mem_size = memory_required; params.mem_buffer = NULL; - struct ggml_rwkv_context * ctx = ggml_rwkv_init(params); + params.no_alloc = false; + struct ggml_context * ctx = ggml_init(params); - std::unordered_map parameters; + std::unordered_map parameters; while (true) { int32_t dim_count; @@ -176,30 +292,24 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr int32_t data_type; read_int32(file, &data_type); - RWKV_ASSERT_NULL( - data_type == 0 || - data_type == 1 || - data_type == 2 || - data_type == 3 || - data_type == 4, - "Unsupported parameter data type %d", - data_type - ); + RWKV_ASSERT_NULL(data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type); - ggml_rwkv_type ggml_rwkv_data_type = FORMAT_TYPE_TO_GGML_RWKV_TYPE[data_type]; + ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type]; - struct ggml_rwkv_tensor * tensor; + RWKV_ASSERT_NULL(ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type); + + struct ggml_tensor * tensor; int32_t x = -1; int32_t y = -1; if (dim_count == 1) { read_int32(file, &x); - tensor = ggml_rwkv_new_tensor_1d(ctx, ggml_rwkv_data_type, x); + tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x); } else if (dim_count == 2) { read_int32(file, &x); read_int32(file, &y); - tensor = ggml_rwkv_new_tensor_2d(ctx, ggml_rwkv_data_type, x, y); + tensor = ggml_new_tensor_2d(ctx, ggml_data_type, x, y); } else { abort(); } @@ -207,7 +317,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr std::string key(key_length, 0); RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == uint32_t(key_length), "Failed to read parameter key"); - RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_rwkv_nbytes(tensor), file) == ggml_rwkv_nbytes(tensor), "Failed to read parameter data"); + RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data"); parameters[key] = tensor; } @@ -255,26 +365,26 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr set_parameter(¶meters, "head.weight", &(model->head)); // Verify order of dimensions - struct ggml_rwkv_tensor * emb = model->emb; + struct ggml_tensor * emb = model->emb; RWKV_ASSERT_NULL(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); - RWKV_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]); - RWKV_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]); + RWKV_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %lld", emb->ne[0]); + RWKV_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %lld", emb->ne[1]); int32_t n_embed = model->n_embed; int32_t n_layer = model->n_layer; // Build graph - struct ggml_rwkv_tensor * state = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_F32, n_layer * 5 * n_embed); + struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); // x = self.w.emb.weight[token] - struct ggml_rwkv_tensor * token_index = ggml_rwkv_new_tensor_1d(ctx, GGML_RWKV_TYPE_I32, 1); - struct ggml_rwkv_tensor * x = ggml_rwkv_get_rows(ctx, model->emb, token_index); + 
struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index); // x = self.layer_norm(x, self.w.blocks[0].ln0) x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias); // We collect parts of new state here. Each part is (n_embed) vector. - struct ggml_rwkv_tensor ** state_parts = new ggml_rwkv_tensor * [n_layer * 5]; + struct ggml_tensor ** state_parts = new ggml_tensor * [n_layer * 5]; for (int i = 0; i < n_layer; i++) { auto layer = model->layers[i]; @@ -282,99 +392,99 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // RWKV/time mixing { // self.layer_norm(x, self.w.blocks[i].ln1) - struct ggml_rwkv_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); + struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); // state[5 * i + 1] - struct ggml_rwkv_tensor * x_prev = ggml_rwkv_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * FP32_SIZE); + struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * sizeof(float)); // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) - struct ggml_rwkv_tensor * xk = ggml_rwkv_add( + struct ggml_tensor * xk = ggml_add( ctx, - ggml_rwkv_mul(ctx, x0, layer.att_time_mix_k), - ggml_rwkv_mul(ctx, x_prev, ggml_rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ggml_mul(ctx, x0, layer.att_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) ); - struct ggml_rwkv_tensor * xv = ggml_rwkv_add( + struct ggml_tensor * xv = ggml_add( ctx, - ggml_rwkv_mul(ctx, x0, layer.att_time_mix_v), - ggml_rwkv_mul(ctx, x_prev, ggml_rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ggml_mul(ctx, x0, layer.att_time_mix_v), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) ); - struct ggml_rwkv_tensor * xr = ggml_rwkv_add( + struct ggml_tensor * xr = ggml_add( ctx, - ggml_rwkv_mul(ctx, x0, layer.att_time_mix_r), - ggml_rwkv_mul(ctx, x_prev, ggml_rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ggml_mul(ctx, x0, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) ); // state[5 * i + 1] = x state_parts[5 * i + 1] = x0; // r = torch.sigmoid(rw @ xr) - struct ggml_rwkv_tensor * r = ggml_rwkv_sigmoid( + struct ggml_tensor * r = rwkv_sigmoid( ctx, - ggml_rwkv_mul_mat(ctx, layer.att_receptance, xr) + ggml_mul_mat(ctx, layer.att_receptance, xr) ); // k = kw @ xk - struct ggml_rwkv_tensor * k = ggml_rwkv_mul_mat(ctx, layer.att_key, xk); + struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); // v = vw @ xv - struct ggml_rwkv_tensor * v = ggml_rwkv_mul_mat(ctx, layer.att_value, xv); + struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); // aa = state[5 * i + 2] // bb = state[5 * i + 3] // pp = state[5 * i + 4] - struct ggml_rwkv_tensor * aa = ggml_rwkv_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * FP32_SIZE); - struct ggml_rwkv_tensor * bb = ggml_rwkv_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * FP32_SIZE); - struct ggml_rwkv_tensor * pp = ggml_rwkv_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * FP32_SIZE); + struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * sizeof(float)); + struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * sizeof(float)); + struct ggml_tensor * pp = ggml_view_1d(ctx, 
state, n_embed, (5 * i + 4) * n_embed * sizeof(float)); // ww = time_first + k - struct ggml_rwkv_tensor * ww = ggml_rwkv_add(ctx, layer.att_time_first, k); + struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); // qq = torch.maximum(pp, ww) - struct ggml_rwkv_tensor * qq = ggml_rwkv_max(ctx, pp, ww); + struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); // e1 = torch.exp(pp - qq) - struct ggml_rwkv_tensor * e1 = ggml_rwkv_exp(ctx, ggml_rwkv_sub(ctx, pp, qq)); + struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); // e2 = torch.exp(ww - qq) - struct ggml_rwkv_tensor * e2 = ggml_rwkv_exp(ctx, ggml_rwkv_sub(ctx, ww, qq)); + struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); // a = e1 * aa + e2 * v - struct ggml_rwkv_tensor * a = ggml_rwkv_add( + struct ggml_tensor * a = ggml_add( ctx, - ggml_rwkv_mul(ctx, e1, aa), - ggml_rwkv_mul(ctx, e2, v) + ggml_mul(ctx, e1, aa), + ggml_mul(ctx, e2, v) ); // b = e1 * bb + e2 - struct ggml_rwkv_tensor * b = ggml_rwkv_add( + struct ggml_tensor * b = ggml_add( ctx, - ggml_rwkv_mul(ctx, e1, bb), + ggml_mul(ctx, e1, bb), e2 ); // wkv = a / b - struct ggml_rwkv_tensor * wkv = ggml_rwkv_div(ctx, a, b); + struct ggml_tensor * wkv = ggml_div(ctx, a, b); // ww = pp + time_decay - ww = ggml_rwkv_add(ctx, pp, layer.att_time_decay); + ww = ggml_add(ctx, pp, layer.att_time_decay); // qq = torch.maximum(ww, k) - qq = ggml_rwkv_max(ctx, ww, k); + qq = rwkv_max(ctx, ww, k); // e1 = torch.exp(ww - qq) - e1 = ggml_rwkv_exp(ctx, ggml_rwkv_sub(ctx, ww, qq)); + e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); // e2 = torch.exp(k - qq) - e2 = ggml_rwkv_exp(ctx, ggml_rwkv_sub(ctx, k, qq)); + e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); // state[5 * i + 2] = e1 * aa + e2 * v - state_parts[5 * i + 2] = ggml_rwkv_add( + state_parts[5 * i + 2] = ggml_add( ctx, - ggml_rwkv_mul(ctx, e1, aa), - ggml_rwkv_mul(ctx, e2, v) + ggml_mul(ctx, e1, aa), + ggml_mul(ctx, e2, v) ); // state[5 * i + 3] = e1 * bb + e2 - state_parts[5 * i + 3] = ggml_rwkv_add( + state_parts[5 * i + 3] = ggml_add( ctx, - ggml_rwkv_mul(ctx, e1, bb), + ggml_mul(ctx, e1, bb), e2 ); // state[5 * i + 4] = qq state_parts[5 * i + 4] = qq; // ow @ (r * wkv) - x = ggml_rwkv_add( + x = ggml_add( ctx, x, - ggml_rwkv_mul_mat( + ggml_mul_mat( ctx, layer.att_output, - ggml_rwkv_mul(ctx, r, wkv) + ggml_mul(ctx, r, wkv) ) ); } @@ -382,42 +492,42 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // FFN/channel mixing { // self.layer_norm(x, self.w.blocks[i].ln2) - struct ggml_rwkv_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); + struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); // state[5 * i + 0] - struct ggml_rwkv_tensor * x_prev = ggml_rwkv_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * FP32_SIZE); + struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float)); // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) - struct ggml_rwkv_tensor * xk = ggml_rwkv_add( + struct ggml_tensor * xk = ggml_add( ctx, - ggml_rwkv_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_rwkv_mul(ctx, x_prev, ggml_rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ggml_mul(ctx, x0, layer.ffn_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) ); - struct ggml_rwkv_tensor * xr = ggml_rwkv_add( + struct ggml_tensor * xr = ggml_add( ctx, - ggml_rwkv_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_rwkv_mul(ctx, 
x_prev, ggml_rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ggml_mul(ctx, x0, layer.ffn_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) ); // state[5 * i + 0] = x state_parts[5 * i + 0] = x0; // r = torch.sigmoid(rw @ xr) - struct ggml_rwkv_tensor * r = ggml_rwkv_sigmoid( + struct ggml_tensor * r = rwkv_sigmoid( ctx, - ggml_rwkv_mul_mat(ctx, layer.ffn_receptance, xr) + ggml_mul_mat(ctx, layer.ffn_receptance, xr) ); // k = torch.square(torch.relu(kw @ xk)) - struct ggml_rwkv_tensor * k = ggml_rwkv_sqr(ctx, ggml_rwkv_relu( + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu( ctx, - ggml_rwkv_mul_mat(ctx, layer.ffn_key, xk) + ggml_mul_mat(ctx, layer.ffn_key, xk) )); // r * (vw @ k) - x = ggml_rwkv_add( + x = ggml_add( ctx, x, - ggml_rwkv_mul( + ggml_mul( ctx, r, - ggml_rwkv_mul_mat(ctx, layer.ffn_value, k) + ggml_mul_mat(ctx, layer.ffn_value, k) ) ); } @@ -427,14 +537,14 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias); // x = (self.w.head.weight @ x).float() - struct ggml_rwkv_tensor * logits = ggml_rwkv_mul_mat(ctx, model->head, x); + struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x); - struct ggml_rwkv_cgraph * graph = (struct ggml_rwkv_cgraph *) calloc(1, sizeof(struct ggml_rwkv_cgraph)); + struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph)); - *graph = ggml_rwkv_build_forward(logits); + *graph = ggml_build_forward(logits); for (int i = 0; i < n_layer * 5; i++) { - ggml_rwkv_build_forward_expand(graph, state_parts[i]); + ggml_build_forward_expand(graph, state_parts[i]); } graph->n_threads = n_threads; @@ -468,51 +578,59 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1); - ggml_rwkv_set_i32(ctx->token_index, 0); - ggml_rwkv_set_i32_1d(ctx->token_index, 0, token); + ggml_set_i32_1d(ctx->token_index, 0, token); if (state_in == NULL) { - ggml_rwkv_set_f32(ctx->state, 0.0F); + ggml_set_f32(ctx->state, 0.0F); for (int i = 0; i < n_layer; i++) { // state[5 * i + 4] = -1e30 - ggml_rwkv_set_f32( - ggml_rwkv_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * FP32_SIZE), + ggml_set_f32( + ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)), -1e30F ); } } else { - memcpy(ctx->state->data, state_in, ctx->state->ne[0] * FP32_SIZE); + memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float)); } - ggml_rwkv_graph_compute(ctx->ctx, ctx->graph); + ggml_graph_compute(ctx->ctx, ctx->graph); for (size_t i = 0; i < size_t(n_layer * 5); i++) { - struct ggml_rwkv_tensor * part = ctx->state_parts[i]; + struct ggml_tensor * part = ctx->state_parts[i]; - memcpy(state_out + i * n_embed, part->data, part->ne[0] * FP32_SIZE); + memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float)); } - memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * FP32_SIZE); - - // Uncomment to measure used memory for adding the value into get_memory_required_mb. 
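The graph code above mirrors the reference RWKV recurrence from the comments (qq = max, e1/e2 = exp of shifted arguments). For readers of the diff, the element-wise core of the time-mixing step can be written out directly; the sketch below is a scalar illustration of that max/exp stabilization for one channel, with aa/bb/pp standing for state[5*i+2..4] and k, v for the matching channels of kw @ xk and vw @ xv. It is an explanatory aid, not code from this patch.

    #include <algorithm>
    #include <cmath>

    struct wkv_state { float aa, bb, pp; };

    static float wkv_channel(wkv_state & s, float k, float v,
                             float time_first, float time_decay) {
        // wkv = (e1*aa + e2*v) / (e1*bb + e2), with the shared maximum qq
        // subtracted inside the exponents so they never overflow.
        float ww = time_first + k;
        float qq = std::max(s.pp, ww);
        float e1 = std::exp(s.pp - qq);
        float e2 = std::exp(ww - qq);
        float wkv = (e1 * s.aa + e2 * v) / (e1 * s.bb + e2);

        // State update uses the same trick, with time_decay applied to pp.
        ww = s.pp + time_decay;
        qq = std::max(ww, k);
        e1 = std::exp(ww - qq);
        e2 = std::exp(k - qq);
        s.aa = e1 * s.aa + e2 * v;
        s.bb = e1 * s.bb + e2;
        s.pp = qq;
        return wkv; // in the graph this is multiplied by sigmoid(r) and projected by att_output
    }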
- //fprintf(stderr, "Used mem: %d MB\n", ggml_rwkv_used_mem(ctx->ctx) / 1024 / 1024); + memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * sizeof(float)); return true; } void rwkv_free(struct rwkv_context * ctx) { - ggml_rwkv_free(ctx->ctx); - - delete ctx->model; - delete ctx->state_parts; - delete ctx; + ctx->model->layers.~vector(); + free(ctx->model); + delete[] ctx->state_parts; + ggml_free(ctx->ctx); + free(ctx->graph); + free(ctx); } -bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { - RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type); +bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) { + int32_t format_type = format_name_to_format_type(format_name); - ggml_rwkv_type type = FORMAT_TYPE_TO_GGML_RWKV_TYPE[q_type]; + RWKV_ASSERT_FALSE(format_type != -1, "Unsupported format \"%s\"", format_name); + + ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[format_type]; + + RWKV_ASSERT_FALSE(type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name); + + // Needed to initialize FP16 lookup table + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } printf("Loading model from '%s'\n", model_file_path_in); @@ -546,7 +664,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type); - data_type = q_type; + data_type = format_type; fout.write((char *) &n_vocab, sizeof(n_vocab)); fout.write((char *) &n_embed, sizeof(n_embed)); @@ -562,7 +680,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode std::vector work; std::vector data_u8; - std::vector data_f16; + std::vector data_f16; std::vector data_f32; std::vector hist_all(1 << 4, 0); @@ -580,6 +698,12 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode break; } + RWKV_ASSERT_FALSE(parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Invalid parameter data type %d", parameter_data_type); + + ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]; + + RWKV_ASSERT_FALSE(parameter_ggml_type != GGML_TYPE_UNKNOWN, "Invalid parameter data type %d", parameter_data_type); + int32_t nelements = 1; int32_t ne[2] = { 1, 1 }; for (int i = 0; i < n_dims; ++i) { @@ -591,16 +715,9 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode finp.read(&name[0], key_length); { - static const char * parameter_data_type_str[] = { - "F32", - "F16", - "Q4_0", - "Q4_1", - "Q4_1_O" - }; - printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]); + printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ggml_type_name(parameter_ggml_type)); - total_size_orig += (size_t) (nelements * ggml_rwkv_type_sizef(FORMAT_TYPE_TO_GGML_RWKV_TYPE[parameter_data_type])); + total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type)); } // Quantize only 2D tensors, except embedding and head matrices. 
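For context, the exported entry points keep their calling pattern after this change: pass NULL as state_in for the first token to get a freshly initialized state, then feed the produced state back in on subsequent calls. A hypothetical caller sketch follows; the file path, token ids and buffer sizes are placeholders, and the real sizes must equal n_layer * 5 * n_embed floats for the state and n_vocab floats for the logits, matching the tensors built in rwkv_init_from_file.

    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include "rwkv.h"

    int main() {
        struct rwkv_context * ctx = rwkv_init_from_file("model-q5_1.bin", 4 /* threads */);
        if (!ctx) return 1;

        const size_t state_elements  = 32 * 5 * 4096; // placeholder: n_layer * 5 * n_embed
        const size_t logits_elements = 50277;         // placeholder: n_vocab

        std::vector<float> state(state_elements);
        std::vector<float> logits(logits_elements);

        int32_t prompt[] = { 12092, 3645 };  // placeholder token ids
        float * state_in = NULL;             // NULL -> rwkv_eval starts from a fresh state
        for (int32_t token : prompt) {
            if (!rwkv_eval(ctx, token, state_in, state.data(), logits.data())) return 1;
            state_in = state.data();         // reuse the updated state for the next token
        }

        printf("%s\n", rwkv_get_system_info_string());
        rwkv_free(ctx);
        return 0;
    }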
@@ -619,17 +736,17 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode if (parameter_data_type == 1) { data_f16.resize(nelements); - finp.read(reinterpret_cast(data_f16.data()), nelements * sizeof(ggml_rwkv_fp16_t)); + finp.read(reinterpret_cast(data_f16.data()), nelements * sizeof(ggml_fp16_t)); data_f32.resize(nelements); for (int i = 0; i < nelements; ++i) { - data_f32[i] = ggml_rwkv_fp16_to_fp32(data_f16[i]); + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); } } else { data_f32.resize(nelements); finp.read(reinterpret_cast(data_f32.data()), nelements * sizeof(float)); } - parameter_data_type = q_type; + parameter_data_type = format_type; } else { const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t); data_u8.resize(nelements * bytes_per_element); @@ -655,23 +772,28 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode std::vector hist_cur(1 << 4, 0); switch (type) { - case GGML_RWKV_TYPE_Q4_0: - { - cur_size = ggml_rwkv_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); - } break; - case GGML_RWKV_TYPE_Q4_1: - { - cur_size = ggml_rwkv_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); - } break; - case GGML_RWKV_TYPE_Q4_1_O: - { - cur_size = ggml_rwkv_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); - } break; - default: - { - fprintf(stderr, "unsupported quantization type %d\n", type); - return false; - } + case GGML_TYPE_Q4_0: + cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + break; + case GGML_TYPE_Q4_1: + cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + break; + case GGML_TYPE_Q4_2: + cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + break; + case GGML_TYPE_Q5_0: + cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + break; + case GGML_TYPE_Q5_1: + cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + break; + case GGML_TYPE_Q8_0: + cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + break; + default: { + fprintf(stderr, "unsupported quantization type %d\n", type); + return false; + } } fout.write(reinterpret_cast(work.data()), cur_size); @@ -697,7 +819,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode printf("original size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0); printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0); - printf("compression ratio = %8.2f%\n", 1.0 * total_size_orig / total_size_new); + printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new); { int64_t sum_all = 0; @@ -726,18 +848,18 @@ const char * rwkv_get_system_info_string(void) { static std::string s; s = ""; - s += "AVX = " + std::to_string(ggml_rwkv_cpu_has_avx()) + " | "; - s += "AVX2 = " + std::to_string(ggml_rwkv_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_rwkv_cpu_has_avx512()) + " | "; - s += "FMA = " + std::to_string(ggml_rwkv_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_rwkv_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_rwkv_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_rwkv_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_rwkv_cpu_has_fp16_va()) + " | "; - s 
+= "WASM_SIMD = " + std::to_string(ggml_rwkv_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_rwkv_cpu_has_blas()) + " | "; - s += "SSE3 = " + std::to_string(ggml_rwkv_cpu_has_sse3()) + " | "; - s += "VSX = " + std::to_string(ggml_rwkv_cpu_has_vsx()) + " | "; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; return s.c_str(); -} +} \ No newline at end of file diff --git a/otherarch/rwkv.h b/otherarch/rwkv.h index 09a67538a..24b9535dd 100644 --- a/otherarch/rwkv.h +++ b/otherarch/rwkv.h @@ -52,12 +52,19 @@ extern "C" { // Frees all allocated memory and the context. RWKV_API void rwkv_free(struct rwkv_context * ctx); - // Quantizes FP32 or FP16 model to one of INT4 formats. + // Quantizes FP32 or FP16 model to one of quantized formats. // Returns false on any error. Error messages would be printed to stderr. // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16. // - model_file_path_out: quantized model will be written here. - // - q_type: set to 2 for GGML_RWKV_TYPE_Q4_0, set to 3 for GGML_RWKV_TYPE_Q4_1. - RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type); + // - format_name: must be one of available format names below. + // Available format names: + // - Q4_0 + // - Q4_1 + // - Q4_2 + // - Q5_0 + // - Q5_1 + // - Q8_0 + RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name); // Returns system information string. RWKV_API const char * rwkv_get_system_info_string(void); @@ -66,4 +73,4 @@ extern "C" { } #endif -#endif +#endif \ No newline at end of file