From 1904a3cae871b51a1126b6aa597c3b6c7654f4d4 Mon Sep 17 00:00:00 2001
From: Justine Tunney <jtunney@gmail.com>
Date: Sat, 3 Jun 2023 10:29:12 -0700
Subject: [PATCH] Sync llama.cpp to 6986c7835adc13ba3f9d933b95671bb1f3984dc6

---
 third_party/ggml/ggml.c   | 3737 +++++++++++++++++++++++++++++++++----
 third_party/ggml/ggml.h   |  209 ++-
 third_party/ggml/llama.cc |   77 +-
 3 files changed, 3666 insertions(+), 357 deletions(-)

diff --git a/third_party/ggml/ggml.c b/third_party/ggml/ggml.c
index 603bafa20..a9e3bd6f8 100644
--- a/third_party/ggml/ggml.c
+++ b/third_party/ggml/ggml.c
@@ -177,7 +177,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#if __AVX__ || __AVX2__ || __AVX512F__
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     // Get absolute values of x vectors
@@ -190,6 +190,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     return _mm_madd_epi16(ones, dot);
 }
 
+#if __AVX__ || __AVX2__ || __AVX512F__
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = _mm256_extractf128_ps(x, 1);
@@ -247,12 +248,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-    // Get absolute values of x vectors
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = _mm256_sign_epi8(y, x);
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
 #if __AVXVNNI__
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
@@ -264,6 +260,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 #endif
 }
 
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
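+    // note: x*y == |x| * (sign(x)*y), so the signed dot product reduces to
+    // the unsigned*signed form that mul_sum_us8_pairs_float computes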
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -284,7 +295,74 @@ static inline __m128i packNibbles( __m256i bytes )
     return _mm_packus_epi16( r0, r1 );
 #endif
 }
-#else
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
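+    // each byte of bit_mask has every bit set except the one that byte tests,
+    // so after the OR a byte is all-ones iff its source bit was set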
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return _mm256_set_m128i(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in the [0..15] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return _mm256_set_m128i(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -301,7 +379,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
 #endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 #if __ARM_NEON
 
@@ -1625,6 +1715,7 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
 inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
 inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
 inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
@@ -1850,6 +1941,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    // First round without accumulation
+    {
+        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute the combined scale for blocks 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
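+        // subtracting 8 recenters the 4-bit quants from [0,15] to [-8,7]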
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute the combined scale for blocks 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        acc_0 = _mm_mul_ps( d_0_1, p0 );
+        acc_1 = _mm_mul_ps( d_0_1, p1 );
+        acc_2 = _mm_mul_ps( d_2_3, p2 );
+        acc_3 = _mm_mul_ps( d_2_3, p3 );
+    }
+
+    // Main loop
+    for (int i = 2; i < nb; i+=2) {
+        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute the combined scale for blocks 0 and 1
+        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute the combined scale for blocks 2 and 3
+        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Accumulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #else
     // scalar
     float sumf = 0.0;
@@ -1942,7 +2153,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
@@ -1965,10 +2176,14 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
-        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
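+        // q4_1 quants are unsigned in [0,15], so the unsigned*signed path applies directly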
+        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
 
         // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
     }
 
     *s = hsum_float_8(acc) + summs;
@@ -2179,6 +2394,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         acc = _mm256_fmadd_ps(d, q, acc);
     }
 
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
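+        // where the 5th bit is clear, OR-ing 0xF0 sign-extends the quant,
+        // mapping the 5-bit values to [-16,15] (i.e. q - 16)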
+        bx = _mm256_set_m128i(bxh, bxl);
+
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
+
     *s = hsum_float_8(acc);
 #else
     // scalar
@@ -2402,11 +2648,45 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const __m256 dy = _mm256_broadcast_ss(&y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
 
         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
 
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
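+        // q5_1 quants stay unsigned: OR-ing in the 5th bit yields values in [0,31]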
+        bx = _mm256_set_m128i(bxh, bxl);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+    }
+
     *s = hsum_float_8(acc) + summs;
 #else
     // scalar
@@ -2497,7 +2777,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
@@ -2511,7 +2791,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
         const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
         // Multiply q with scale and accumulate
+#if defined(__AVX2__)
         acc = _mm256_fmadd_ps( d, q, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
     }
 
     *s = hsum_float_8(acc);
@@ -2652,11 +2936,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);   }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
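+// tanh approximation of GELU(x) = x*Phi(x):
+//   0.5*x*(1 + tanh(sqrt(2/pi)*(x + GELU_COEF_A*x^3)))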
+inline static float ggml_gelu_f32(float x) {
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
 #ifdef GGML_GELU_FP16
 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     uint16_t t;
@@ -2698,6 +2990,29 @@ inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static float ggml_silu_backward_f32(float x, float dy) {
+    const float s = 1.0f/(1.0f + expf(-x));
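+    // d/dx [x*s] = s + x*s*(1-s) = s*(1 + x*(1-s)), with s = sigmoid(x)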
+    return dy*s*(1.0f + x*(1.0f - s));
+}
+
+#ifdef GGML_SILU_FP16
+inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+    for (int i = 0; i < n; ++i) {
+        // the forward silu was computed from the f16-rounded value of x[i],
+        // not from x[i] itself, so take the derivative at that f16 value:
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        float usedx = GGML_FP16_TO_FP32(fp16);
+        dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
+    }
+}
+#else
+inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
+    }
+}
+#endif
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     int i = 0;
@@ -2839,12 +3154,16 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "DUP",
     "ADD",
+    "ADD1",
+    "ACC",
     "SUB",
     "MUL",
     "DIV",
     "SQR",
     "SQRT",
+    "LOG",
     "SUM",
+    "SUM_ROWS",
     "MEAN",
     "REPEAT",
     "ABS",
@@ -2854,12 +3173,15 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "RELU",
     "GELU",
     "SILU",
+    "SILU_BACK",
     "NORM",
     "RMS_NORM",
+    "RMS_NORM_BACK",
 
     "MUL_MAT",
 
     "SCALE",
+    "SET",
     "CPY",
     "CONT",
     "RESHAPE",
@@ -2867,9 +3189,13 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "PERMUTE",
     "TRANSPOSE",
     "GET_ROWS",
+    "GET_ROWS_BACK",
+    "DIAG",
     "DIAG_MASK_INF",
+    "DIAG_MASK_ZERO",
     "SOFT_MAX",
     "ROPE",
+    "ROPE_BACK",
     "ALIBI",
     "CONV_1D_1S",
     "CONV_1D_2S",
@@ -2881,19 +3207,23 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
+static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
     "x",
     "x+y",
+    "x+y",
+    "view(x,nb,offset)+=y->x",
     "x-y",
     "x*y",
     "x/y",
     "x^2",
     "√x",
+    "log(x)",
     "Σx",
+    "Σx_k",
     "Σx/n",
     "repeat(x)",
     "abs(x)",
@@ -2903,12 +3233,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "relu(x)",
     "gelu(x)",
     "silu(x)",
+    "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
+    "rms_norm_back(x)",
 
     "X*Y",
 
     "x*v",
+    "y-\\>view(x)",
     "x-\\>y",
     "cont(x)",
     "reshape(x)",
@@ -2916,9 +3249,13 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "permute(x)",
     "transpose(x)",
     "get_rows(x)",
+    "get_rows_back(x)",
+    "diag(x)",
     "diag_mask_inf(x)",
+    "diag_mask_zero(x)",
     "soft_max(x)",
     "rope(x)",
+    "rope_back(x)",
     "alibi(x)",
     "conv_1d_1s(x)",
     "conv_1d_2s(x)",
@@ -2930,7 +3267,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
+static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3169,9 +3506,9 @@ static inline int ggml_up32(int n) {
     return (n + 31) & ~31;
 }
 
-static inline int ggml_up64(int n) {
-    return (n + 63) & ~63;
-}
+//static inline int ggml_up64(int n) {
+//    return (n + 63) & ~63;
+//}
 
 static inline int ggml_up(int n, int m) {
     // assert m is a power of 2
@@ -3310,6 +3647,20 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
+// IMPORTANT:
+// when creating "opt" tensors, always save and load the scratch buffer
+// this is an error-prone process, but it is necessary to support inplace
+// operators when using scratch buffers
+// TODO: implement a better way
+void ggml_scratch_save(struct ggml_context * ctx) {
+    ctx->scratch_save = ctx->scratch;
+    ctx->scratch.data = NULL;
+}
+
+void ggml_scratch_load(struct ggml_context * ctx) {
+    ctx->scratch = ctx->scratch_save;
+}
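+
+// typical usage (cf. ggml_new_i32 / ggml_new_f32 below):
+//
+//   ggml_scratch_save(ctx);   // park the scratch buffer
+//   ... allocate small "opt" tensors from the main pool ...
+//   ggml_scratch_load(ctx);   // restore it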
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_tensor * ggml_new_tensor_impl(
@@ -3398,6 +3749,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
+        /*.backend      =*/ GGML_BACKEND_CPU,
         /*.n_dims       =*/ n_dims,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
@@ -3480,12 +3832,11 @@ struct ggml_tensor * ggml_new_tensor_4d(
 }
 
 struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
-    ctx->scratch_save = ctx->scratch;
-    ctx->scratch.data = NULL;
+    ggml_scratch_save(ctx);
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
 
-    ctx->scratch = ctx->scratch_save;
+    ggml_scratch_load(ctx);
 
     ggml_set_i32(result, value);
 
@@ -3493,12 +3844,11 @@ struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
 }
 
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
-    ctx->scratch_save = ctx->scratch;
-    ctx->scratch.data = NULL;
+    ggml_scratch_save(ctx);
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
 
-    ctx->scratch = ctx->scratch_save;
+    ggml_scratch_load(ctx);
 
     ggml_set_f32(result, value);
 
@@ -3864,6 +4214,113 @@ struct ggml_tensor * ggml_add_inplace(
     return ggml_add_impl(ctx, a, b, true);
 }
 
+// ggml_add1
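+// (adds the scalar b to every element of a; cf. ggml_vec_add1_f32 above)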
+
+struct ggml_tensor * ggml_add1_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    GGML_ASSERT(ggml_is_scalar(b));
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_ADD1;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add1(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_add1_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add1_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_add1_impl(ctx, a, b, true);
+}
+
+// ggml_acc
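+// (accumulates b into a strided view of a: view(a, nb1, nb2, nb3, offset) += b)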
+
+struct ggml_tensor * ggml_acc_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        size_t               nb1,
+        size_t               nb2,
+        size_t               nb3,
+        size_t               offset,
+        bool inplace) {
+    GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
+
+    ((int32_t *) c->data)[0] = nb1;
+    ((int32_t *) c->data)[1] = nb2;
+    ((int32_t *) c->data)[2] = nb3;
+    ((int32_t *) c->data)[3] = offset;
+    ((int32_t *) c->data)[4] = inplace ? 1 : 0;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_ACC;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_acc(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        size_t               nb1,
+        size_t               nb2,
+        size_t               nb3,
+        size_t               offset) {
+    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+}
+
+struct ggml_tensor * ggml_acc_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        size_t               nb1,
+        size_t               nb2,
+        size_t               nb3,
+        size_t               offset) {
+    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+}
+
 // ggml_sub
 
 struct ggml_tensor * ggml_sub_impl(
@@ -4057,6 +4514,41 @@ struct ggml_tensor * ggml_sqrt_inplace(
     return ggml_sqrt_impl(ctx, a, true);
 }
 
+// ggml_log
+
+struct ggml_tensor * ggml_log_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_LOG;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_log(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_log_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_log_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_log_impl(ctx, a, true);
+}
+
 // ggml_sum
 
 struct ggml_tensor * ggml_sum(
@@ -4078,6 +4570,33 @@ struct ggml_tensor * ggml_sum(
     return result;
 }
 
+// ggml_sum_rows
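+// (sums each row of a down to one element: shape [a,b,c,d] -> [1,b,c,d])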
+
+struct ggml_tensor * ggml_sum_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    int64_t ne[4] = {1,1,1,1};
+    for (int i=1; i<a->n_dims; ++i) {
+        ne[i] = a->ne[i];
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, a->n_dims, ne);
+
+    result->op   = GGML_OP_SUM_ROWS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
 // ggml_mean
 
 struct ggml_tensor * ggml_mean(
@@ -4368,6 +4887,29 @@ struct ggml_tensor * ggml_silu_inplace(
     return ggml_silu_impl(ctx, a, true);
 }
 
+// ggml_silu_back
+
+struct ggml_tensor * ggml_silu_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SILU_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_norm
 
 struct ggml_tensor * ggml_norm_impl(
@@ -4410,7 +4952,6 @@ struct ggml_tensor * ggml_rms_norm_impl(
     bool is_node = false;
 
     if (!inplace && (a->grad)) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4436,6 +4977,28 @@ struct ggml_tensor * ggml_rms_norm_inplace(
     return ggml_rms_norm_impl(ctx, a, true);
 }
 
+struct ggml_tensor * ggml_rms_norm_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    bool is_node = false;
+
+    if (a->grad) {
+        // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RMS_NORM_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_mul_mat
 
 struct ggml_tensor * ggml_mul_mat(
@@ -4475,13 +5038,10 @@ struct ggml_tensor * ggml_scale_impl(
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     result->op   = GGML_OP_SCALE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4505,6 +5065,106 @@ struct ggml_tensor * ggml_scale_inplace(
     return ggml_scale_impl(ctx, a, b, true);
 }
 
+// ggml_set
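+// (writes b into a strided view of a: view(a, nb1, nb2, nb3, offset) = b)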
+
+struct ggml_tensor * ggml_set_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset,
+        bool inplace) {
+    GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    // make a view of the destination
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
+
+    (( int32_t * ) c->data)[0] = nb1;
+    (( int32_t * ) c->data)[1] = nb2;
+    (( int32_t * ) c->data)[2] = nb3;
+    (( int32_t * ) c->data)[3] = offset;
+    (( int32_t * ) c->data)[4] = inplace ? 1 : 0;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_SET;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_set(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+}
+
+struct ggml_tensor * ggml_set_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+}
+
+struct ggml_tensor * ggml_set_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  b,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
+}
+
+struct ggml_tensor * ggml_set_1d_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  b,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
+}
+
+struct ggml_tensor * ggml_set_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  b,
+        size_t                nb1,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
+}
+
+struct ggml_tensor * ggml_set_2d_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  b,
+        size_t                nb1,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
+}
+
 // ggml_cpy
 
 struct ggml_tensor * ggml_cpy_impl(
@@ -4517,7 +5177,6 @@ struct ggml_tensor * ggml_cpy_impl(
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4555,7 +5214,6 @@ struct ggml_tensor * ggml_cont_impl(
     bool is_node = false;
 
     if (!inplace && a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4593,11 +5251,15 @@ struct ggml_tensor * ggml_reshape(
 
     bool is_node = false;
 
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
+    if (a->grad) {
         is_node = true;
     }
 
+    if (b->grad) {
+        // gradient propagation is not supported
+        //GGML_ASSERT(false);
+    }
+
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
 
     result->op   = GGML_OP_RESHAPE;
@@ -4608,6 +5270,30 @@ struct ggml_tensor * ggml_reshape(
     return result;
 }
 
+struct ggml_tensor * ggml_reshape_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[1] = { ne0 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
 struct ggml_tensor * ggml_reshape_2d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -4619,7 +5305,6 @@ struct ggml_tensor * ggml_reshape_2d(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4646,7 +5331,6 @@ struct ggml_tensor * ggml_reshape_3d(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4661,6 +5345,34 @@ struct ggml_tensor * ggml_reshape_3d(
     return result;
 }
 
+struct ggml_tensor * ggml_reshape_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
 // ggml_view_1d
 
 struct ggml_tensor * ggml_view_1d(
@@ -4668,16 +5380,23 @@ struct ggml_tensor * ggml_view_1d(
         struct ggml_tensor  * a,
         int64_t               ne0,
         size_t                offset) {
+
+    bool is_node = false;
+
     if (a->grad) {
-        GGML_ASSERT(false); // gradient propagation is not supported
+        is_node = true;
     }
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
 
     result->op   = GGML_OP_VIEW;
-    result->grad = NULL;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = NULL; // TODO: maybe store the offset here?
+    result->src1 = NULL;
+
+    if (is_node) {
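+        // stash the view offset so it can be recovered later (e.g. by the backward pass)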
+        memcpy(result->padding, &offset, sizeof(offset));
+    }
 
     return result;
 }
@@ -4691,8 +5410,11 @@ struct ggml_tensor * ggml_view_2d(
         int64_t               ne1,
         size_t                nb1,
         size_t                offset) {
+
+    bool is_node = false;
+
     if (a->grad) {
-        GGML_ASSERT(false); // gradient propagation is not supported
+        is_node = true;
     }
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
@@ -4704,9 +5426,13 @@ struct ggml_tensor * ggml_view_2d(
     result->nb[3] = result->nb[2];
 
     result->op   = GGML_OP_VIEW;
-    result->grad = NULL;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = NULL; // TODO: maybe store the offset here?
+    result->src1 = NULL;
+
+    if (is_node) {
+        memcpy(result->padding, &offset, sizeof(offset));
+    }
 
     return result;
 }
@@ -4722,8 +5448,11 @@ struct ggml_tensor * ggml_view_3d(
         size_t                nb1,
         size_t                nb2,
         size_t                offset) {
+
+    bool is_node = false;
+
     if (a->grad) {
-        GGML_ASSERT(false); // gradient propagation is not supported
+        is_node = true;
     }
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
@@ -4735,9 +5464,53 @@ struct ggml_tensor * ggml_view_3d(
     result->nb[3] = result->nb[2]*ne2;
 
     result->op   = GGML_OP_VIEW;
-    result->grad = NULL;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = NULL; // TODO: maybe store the offset here?
+    result->src1 = NULL;
+
+    if (is_node) {
+        memcpy(result->padding, &offset, sizeof(offset));
+    }
+
+    return result;
+}
+
+// ggml_view_4d
+
+struct ggml_tensor * ggml_view_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 };
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = nb2;
+    result->nb[3] = nb3;
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    if (is_node) {
+        memcpy(result->padding, &offset, sizeof(offset));
+    }
 
     return result;
 }
@@ -4766,7 +5539,6 @@ struct ggml_tensor * ggml_permute(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4798,7 +5570,14 @@ struct ggml_tensor * ggml_permute(
     result->op   = GGML_OP_PERMUTE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
-    result->src1 = NULL; // TODO: maybe store the permutation here?
+    result->src1 = NULL;
+
+    if (is_node) {
+        result->padding[0] = axis0;
+        result->padding[1] = axis1;
+        result->padding[2] = axis2;
+        result->padding[3] = axis3;
+    }
 
     return result;
 }
@@ -4811,7 +5590,6 @@ struct ggml_tensor * ggml_transpose(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4842,7 +5620,6 @@ struct ggml_tensor * ggml_get_rows(
     bool is_node = false;
 
     if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -4858,24 +5635,82 @@ struct ggml_tensor * ggml_get_rows(
     return result;
 }
 
-// ggml_diag_mask_inf
+// ggml_get_rows_back
 
-struct ggml_tensor * ggml_diag_mask_inf(
+struct ggml_tensor * ggml_get_rows_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past) {
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
+
     bool is_node = false;
 
-    if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
-    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
-    ggml_set_name(b, "n_past");
+    // TODO: implement non F32 return
+    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);
+
+    result->op   = GGML_OP_GET_ROWS_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
+// ggml_diag
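+// (expands a vector (ne1 == 1) into a square matrix with the vector on its diagonal)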
+
+struct ggml_tensor * ggml_diag(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    GGML_ASSERT(a->ne[1] == 1);
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne);
+
+    result->op   = GGML_OP_DIAG;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_diag_mask_inf
+
+struct ggml_tensor * ggml_diag_mask_inf_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
+
+    ggml_scratch_load(ctx);
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4885,21 +5720,81 @@ struct ggml_tensor * ggml_diag_mask_inf(
     return result;
 }
 
-// ggml_soft_max
-
-struct ggml_tensor * ggml_soft_max(
+struct ggml_tensor * ggml_diag_mask_inf(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
+}
+
+struct ggml_tensor * ggml_diag_mask_inf_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
+}
+
+// ggml_diag_mask_zero
+
+struct ggml_tensor * ggml_diag_mask_zero_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        bool                  inplace) {
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ggml_set_name(b, "n_past, inplace");
+
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_DIAG_MASK_ZERO;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_diag_mask_zero(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_zero_impl(ctx, a, n_past, false);
+}
+
+struct ggml_tensor * ggml_diag_mask_zero_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_zero_impl(ctx, a, n_past, true);
+}
+
+// ggml_soft_max
+
+struct ggml_tensor * ggml_soft_max_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4909,14 +5804,80 @@ struct ggml_tensor * ggml_soft_max(
     return result;
 }
 
+struct ggml_tensor * ggml_soft_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_soft_max_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_soft_max_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_soft_max_impl(ctx, a, true);
+}
+
 // ggml_rope
 
+struct ggml_tensor * ggml_rope_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode,
+        bool                  inplace) {
+    GGML_ASSERT(n_past >= 0);
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = n_dims;
+    ((int32_t *) b->data)[2] = mode;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_ROPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 struct ggml_tensor * ggml_rope(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
         int                   mode) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
+}
+
+struct ggml_tensor * ggml_rope_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
+}
+
+// ggml_rope_back
+
+struct ggml_tensor * ggml_rope_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
@@ -4925,9 +5886,9 @@ struct ggml_tensor * ggml_rope(
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
 
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
     ((int32_t *) b->data)[0] = n_past;
@@ -4935,7 +5896,9 @@ struct ggml_tensor * ggml_rope(
     ((int32_t *) b->data)[2] = mode;
     ggml_set_name(b, "n_past, n_dims, mode");
 
-    result->op   = GGML_OP_ROPE;
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
@@ -4962,10 +5925,15 @@ struct ggml_tensor * ggml_alibi(
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_head;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_ALIBI;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -5189,6 +6157,38 @@ void ggml_set_param(
 
 // ggml_compute_forward_dup
 
+static void ggml_compute_forward_dup_same_cont(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb0 = dst->nb[0];
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by elements
+    const int ne = ggml_nelements(dst);
+    const int dr = (ne + nth - 1) / nth;
+    const int ie0 = dr * ith;
+    const int ie1 = MIN(ie0 + dr, ne);
+
+    if (ie0 < ie1) {
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb00),
+            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+    }
+}
+
 static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5223,17 +6223,7 @@ static void ggml_compute_forward_dup_f16(
     const int nth = params->nth; // number of threads
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        // parallelize by elements
-        const int ne = ggml_nelements(dst);
-        const int dr = (ne + nth - 1) / nth;
-        const int ie0 = dr * ith;
-        const int ie1 = MIN(ie0 + dr, ne);
-
-        memcpy(
-            ((char *)  dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb00),
-            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
-
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
         return;
     }
 
@@ -5522,17 +6512,7 @@ static void ggml_compute_forward_dup_f32(
     const int nth = params->nth; // number of threads
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        // parallelize by elements
-        const int ne = ggml_nelements(dst);
-        const int dr = (ne + nth - 1) / nth;
-        const int ie0 = dr * ith;
-        const int ie1 = MIN(ie0 + dr, ne);
-
-        memcpy(
-            ((char *)  dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb00),
-            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
-
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
         return;
     }
 
@@ -5787,6 +6767,10 @@ static void ggml_compute_forward_dup(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
+        return;
+    }
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
@@ -5819,44 +6803,73 @@ static void ggml_compute_forward_add_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
 
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
 
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (nb10 == sizeof(float)) {
-        for (int j = ith; j < n; j += nth) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+
 #ifdef GGML_USE_ACCELERATE
             vDSP_vadd(
-                    (float *) ((char *) src0->data + j*nb01), 1,
-                    (float *) ((char *) src1->data + j*nb11), 1,
-                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                    ne0);
 #else
-            ggml_vec_add_f32(nc,
-                    (float *) ((char *) dst->data  + j*nb1),
-                    (float *) ((char *) src0->data + j*nb01),
-                    (float *) ((char *) src1->data + j*nb11));
+            ggml_vec_add_f32(ne0,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
 #endif
         }
     } else {
         // src1 is not contiguous
-        for (int j = ith; j < n; j += nth) {
-            float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
-            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
-            for (int i = 0; i < nc; i++) {
-                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            for (int i0 = 0; i0 < ne0; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }
         }
     }
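The rewritten add kernel walks a flat row index `ir` and unflattens it into `(i1, i2, i3)` coordinates instead of striding rows by thread count. A minimal, self-contained demonstration of that index math, with made-up dimensions:

```c
/* Sketch of how a flat row index ir is unflattened into (i1, i2, i3),
   mirroring the index math in ggml_compute_forward_add_f32 above.
   The dimensions below are invented for illustration. */
#include <stdio.h>

int main(void) {
    const int ne1 = 3, ne2 = 2;          // rows per matrix, matrices per batch
    const int nr  = ne1 * ne2 * 2;       // total rows for 2 batches

    for (int ir = 0; ir < nr; ++ir) {
        const int i3 = ir / (ne2 * ne1);               // batch index
        const int i2 = (ir - i3 * ne2 * ne1) / ne1;    // matrix index
        const int i1 = ir - i3 * ne2 * ne1 - i2 * ne1; // row index
        printf("ir=%2d -> (i1=%d, i2=%d, i3=%d)\n", ir, i1, i2, i3);
    }
    return 0;
}
```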
@@ -5876,17 +6889,25 @@ static void ggml_compute_forward_add_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
 
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -5895,13 +6916,26 @@ static void ggml_compute_forward_add_f16_f32(
     GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (nb10 == sizeof(float)) {
-        for (int j = ith; j < n; j += nth) {
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
-            for (int i = 0; i < nc; i++) {
-                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+            for (int i = 0; i < ne0; i++) {
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
             }
         }
     }
@@ -5925,32 +6959,53 @@ static void ggml_compute_forward_add_f16_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
 
     const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
     GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (nb10 == sizeof(ggml_fp16_t)) {
-        for (int j = ith; j < n; j += nth) {
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
-            for (int i = 0; i < nc; i++) {
-                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+            for (int i = 0; i < ne0; i++) {
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
             }
         }
     }
@@ -5971,50 +7026,36 @@ static void ggml_compute_forward_add_q_f32(
         return;
     }
 
+    const int nr  = ggml_nrows(src0);
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    //const int64_t ne03 = src0->ne[3];
 
-    //const int64_t ne10 = src1->ne[0];
-    //const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
-    //const int64_t ne0  = dst->ne[0];
-    //const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
 
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    const size_t nb0  = dst->nb[0];
+    const size_t nb1  = dst->nb[1];
+    const size_t nb2  = dst->nb[2];
+    const size_t nb3  = dst->nb[3];
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
-
     const enum ggml_type type = src0->type;
     dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
     quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
 
     // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
     GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
@@ -6026,9 +7067,6 @@ static void ggml_compute_forward_add_q_f32(
     GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -6111,6 +7149,428 @@ static void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add1
+
+static void ggml_compute_forward_add1_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+        UNUSED(ggml_vec_add1_f32);
+
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                (float *) ((char *) src1->data), 0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                ne0);
+#else
+        ggml_vec_add1_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+               *(float *) src1->data);
+#endif
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // scalar to add
+    const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0
+    GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
+        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb3 ));
+
+        assert(ne0 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne0);
+        // add src1
+        ggml_vec_acc1_f32(ne0, wdata, v);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne0);
+    }
+}
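Since quantized rows cannot be updated element-wise, `ggml_compute_forward_add1_q_f32` round-trips each row through the per-thread float scratch `wdata`. The toy program below illustrates the dequantize → add → requantize pattern; its 8-bit scheme is invented for the example and is not one of the ggml quantization formats:

```c
/* Illustrative round-trip behind ggml_compute_forward_add1_q_f32:
   dequantize a row to float scratch, add the scalar, re-quantize.
   The linear 8-bit scheme here is a stand-in, NOT a ggml format. */
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static void dequant_row(const int8_t *q, float scale, float *out, int n) {
    for (int i = 0; i < n; i++) out[i] = q[i] * scale;
}

static void quant_row(const float *x, float scale, int8_t *q, int n) {
    for (int i = 0; i < n; i++) q[i] = (int8_t) lroundf(x[i] / scale);
}

int main(void) {
    enum { N = 4 };
    const float scale = 0.5f;
    int8_t row[N] = { 2, -4, 6, 0 };   // encodes { 1.0, -2.0, 3.0, 0.0 }
    float  tmp[N];                     // per-thread scratch, like wdata

    dequant_row(row, scale, tmp, N);             // unpack to float
    for (int i = 0; i < N; i++) tmp[i] += 0.25f; // add the scalar v
    quant_row(tmp, scale, row, N);               // pack back into the row

    for (int i = 0; i < N; i++) printf("%d ", row[i]);
    printf("\n");
    return 0;
}
```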
+
+static void ggml_compute_forward_add1(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add1_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add1_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add1_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+
+// ggml_compute_forward_acc
+
+static void ggml_compute_forward_acc_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(opt0) == 5);
+
+    // view src0 and dst with these strides and data offset in bytes during acc
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) opt0->data)[0];
+    size_t nb2     = ((int32_t *) opt0->data)[1];
+    size_t nb3     = ((int32_t *) opt0->data)[2];
+    size_t offset  = ((int32_t *) opt0->data)[3];
+    bool   inplace = (bool) ((int32_t *) opt0->data)[4];
+
+    if (!inplace && (params->type == GGML_TASK_INIT)) {
+        // memcpy needs to be synchronized across threads to avoid race conditions.
+        // => do it in INIT phase
+        memcpy(
+            ((char *)  dst->data),
+            ((char *) src0->data),
+            ggml_nbytes(dst));
+    }
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
+
+    // src0 and dst as viewed during acc
+    const size_t nb0 = ggml_element_size(src0);
+
+    const size_t nb00 = nb0;
+    const size_t nb01 = nb1;
+    const size_t nb02 = nb2;
+    const size_t nb03 = nb3;
+
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+#ifdef GGML_USE_ACCELERATE
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
+#else
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+    }
+}
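`opt0` packs the view description as five `int32` values. A sketch of how they are unpacked and how a row address inside the viewed `dst` is formed; the numeric strides are arbitrary examples:

```c
/* Sketch of the view arithmetic used by ggml_compute_forward_acc_f32:
   opt0 carries { nb1, nb2, nb3, offset, inplace }, and each row of
   src1 is added at dst->data + offset + i1*nb1 + i2*nb2 + i3*nb3.
   All numbers below are illustrative. */
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // pretend contents of opt0->data
    int32_t opt0[5] = { 64, 256, 1024, 128, 0 };

    const size_t nb1    = (size_t) opt0[0]; // byte stride between rows
    const size_t nb2    = (size_t) opt0[1]; // byte stride between matrices
    const size_t nb3    = (size_t) opt0[2]; // byte stride between batches
    const size_t offset = (size_t) opt0[3]; // start of the view, in bytes
    const int  inplace  = opt0[4];          // skip the initial memcpy?

    // byte offset of row (i1=2, i2=1, i3=0) inside the viewed dst
    const size_t byte_off = offset + 2*nb1 + 1*nb2 + 0*nb3;
    printf("inplace=%d row offset=%zu bytes\n", inplace, byte_off);
    return 0;
}
```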
+
+static void ggml_compute_forward_acc(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_acc_f32(params, src0, src1, opt0, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_sub
 
 static void ggml_compute_forward_sub_f32(
@@ -6125,18 +7585,68 @@ static void ggml_compute_forward_sub_f32(
         return;
     }
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-    assert(src1->nb[0] == sizeof(float));
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sub_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int ir = 0; ir < nr; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vsub(
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                    ne0);
+#else
+            ggml_vec_sub_f32(ne0,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+        }
+    } else {
+        // src1 is not contiguous
+        for (int ir = 0; ir < nr; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            for (int i0 = 0; i0 < ne0; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
+            }
+        }
     }
 }
 
@@ -6164,25 +7674,78 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
     assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-    assert(src1->nb[0] == sizeof(float));
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
-    for (int i = 0; i < n; i++) {
-        ggml_vec_mul_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int ir = ith; ir < nr; ir += nth) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+            UNUSED(ggml_vec_mul_f32);
+
+            vDSP_vmul(
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                    ne0);
+#else
+            ggml_vec_mul_f32(ne0,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+        }
+    } else {
+        // src1 is not contiguous
+        for (int ir = ith; ir < nr; ir += nth) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            for (int i0 = 0; i0 < ne0; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
+            }
+        }
     }
 }
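Unlike the add kernels, the multiply kernel hands rows out round-robin (`ir = ith; ir < nr; ir += nth`) rather than in contiguous blocks. A small comparison of which rows each thread would touch under the two schemes, with made-up `nr` and `nth`:

```c
/* Round-robin (strided) vs contiguous (blocked) row partitioning,
   as used by the mul and add kernels above respectively. */
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr = 7, nth = 3;
    for (int ith = 0; ith < nth; ith++) {
        printf("thread %d strided:", ith);
        for (int ir = ith; ir < nr; ir += nth) printf(" %d", ir);

        const int dr  = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("  blocked: [%d, %d)\n", ir0, ir1);
    }
    return 0;
}
```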
 
@@ -6217,18 +7780,68 @@ static void ggml_compute_forward_div_f32(
         return;
     }
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr  = ggml_nrows(src0);
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-    assert(src1->nb[0] == sizeof(float));
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
-    for (int i = 0; i < n; i++) {
-        ggml_vec_div_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int ir = 0; ir < nr; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vdiv(
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                    ne0);
+#else
+            ggml_vec_div_f32(ne0,
+                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+        }
+    } else {
+        // src1 is not contiguous
+        for (int ir = 0; ir < nr; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            for (int i0 = 0; i0 < ne0; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
+            }
+        }
     }
 }
 
@@ -6333,6 +7946,49 @@ static void ggml_compute_forward_sqrt(
     }
 }
 
+
+// ggml_compute_forward_log
+
+static void ggml_compute_forward_log_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_log_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_log(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_log_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_sum
 
 static void ggml_compute_forward_sum_f32(
@@ -6390,6 +8046,73 @@ static void ggml_compute_forward_sum(
     }
 }
 
+// ggml_compute_forward_sum_rows
+
+static void ggml_compute_forward_sum_rows_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    GGML_ASSERT(ne0 == 1);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                float* src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                float* dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
+                float row_sum = 0;
+                ggml_vec_sum_f32(ne00, &row_sum, src_row);
+                dst_row[0] = row_sum;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_sum_rows(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_rows_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
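A self-contained rendition of the reduction `ggml_compute_forward_sum_rows_f32` performs: each row of length `ne00` collapses to a single sum, so `dst` has `ne0 == 1`. Dimensions below are illustrative:

```c
/* Standalone sketch of the sum_rows reduction: a (ne00 x ne01)
   matrix collapses to an (1 x ne01) column of row sums. */
#include <stdio.h>

int main(void) {
    enum { NE00 = 4, NE01 = 2 };
    const float src[NE01][NE00] = {
        { 1, 2, 3, 4 },   // sums to 10
        { 5, 6, 7, 8 },   // sums to 26
    };
    float dst[NE01][1];

    for (int i1 = 0; i1 < NE01; i1++) {
        float row_sum = 0.0f;
        for (int i0 = 0; i0 < NE00; i0++) row_sum += src[i1][i0];
        dst[i1][0] = row_sum;
    }

    printf("%g %g\n", dst[0][0], dst[1][0]); // 10 26
    return 0;
}
```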
+
 // ggml_compute_forward_mean
 
 static void ggml_compute_forward_mean_f32(
@@ -6467,37 +8190,58 @@ static void ggml_compute_forward_repeat_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_can_repeat(src0, dst));
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    // TODO: implement support for rank > 2 tensors
-    assert(src0->ne[2] == 1);
-    assert(src0->ne[3] == 1);
-    assert( dst->ne[2] == 1);
-    assert( dst->ne[3] == 1);
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
 
-    const int nc  = dst->ne[0];
-    const int nr  = dst->ne[1];
-    const int nc0 = src0->ne[0];
-    const int nr0 = src0->ne[1];
-    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0  = dst->nb[0];
+    const size_t nb1  = dst->nb[1];
+    const size_t nb2  = dst->nb[2];
+    const size_t nb3  = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
 
     // TODO: support for transposed / permuted tensors
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
 
     // TODO: maybe this is not optimal?
-    for (int i = 0; i < nrr; i++) {
-        for (int j = 0; j < ncr; j++) {
-            for (int k = 0; k < nr0; k++) {
-                ggml_vec_cpy_f32(nc0,
-                        (float *) ((char *)  dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])),
-                        (float *) ((char *) src0->data + (        k)*(src0->nb[1])));
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_vec_cpy_f32(ne00,
+                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
+                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
+                            }
+                        }
+                    }
+                }
             }
         }
     }
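The seven-deep loop nest generalizes repeat to rank-4 tensors. A 2-D reduction of the same tiling (dims 2 and 3 dropped), repeating a 2x2 source into a 4x6 destination, so `nr0 = 3` and `nr1 = 2`:

```c
/* Worked 2-D example of the tiling in ggml_compute_forward_repeat_f32;
   the real kernel copies whole rows with ggml_vec_cpy_f32, this sketch
   copies per element for clarity. */
#include <stdio.h>

int main(void) {
    enum { NE00 = 2, NE01 = 2, NE0 = 6, NE1 = 4 };
    const int nr0 = NE0 / NE00; // 3, integral by the ggml_can_repeat check
    const int nr1 = NE1 / NE01; // 2

    const float src[NE01][NE00] = { { 1, 2 }, { 3, 4 } };
    float dst[NE1][NE0];

    for (int i1 = 0; i1 < nr1; i1++)
        for (int k1 = 0; k1 < NE01; k1++)
            for (int i0 = 0; i0 < nr0; i0++)
                for (int k0 = 0; k0 < NE00; k0++)
                    dst[i1*NE01 + k1][i0*NE00 + k0] = src[k1][k0];

    for (int r = 0; r < NE1; r++) {
        for (int c = 0; c < NE0; c++) printf("%g ", dst[r][c]);
        printf("\n");
    }
    return 0;
}
```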
@@ -6850,6 +8594,70 @@ static void ggml_compute_forward_silu(
 }
 
 
+// ggml_compute_forward_silu_back
+
+static void ggml_compute_forward_silu_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * grad,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, grad));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * grad,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_back_f32(params, src0, grad, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
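The body of `ggml_vec_silu_backward_f32` is not part of this hunk; assuming the standard SiLU derivative, with `s = sigmoid(x)`, it should compute `dy * s * (1 + x*(1 - s))` per element. A sketch under that assumption:

```c
/* Assumed per-element formula behind ggml_vec_silu_backward_f32:
   silu(x) = x * sigmoid(x), so with s = sigmoid(x),
   d/dx silu(x) = s * (1 + x * (1 - s)). */
#include <math.h>
#include <stdio.h>

static float silu_backward(float x, float dy) {
    const float s = 1.0f / (1.0f + expf(-x)); // sigmoid(x)
    return dy * s * (1.0f + x * (1.0f - s));  // dy * silu'(x)
}

int main(void) {
    // silu'(0) = 0.5, so an incoming gradient of 2.0 becomes 1.0
    printf("%g\n", silu_backward(0.0f, 2.0f));
    return 0;
}
```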
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -7004,6 +8812,195 @@ static void ggml_compute_forward_rms_norm(
 }
 
 
+static void ggml_compute_forward_rms_norm_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const float eps = 1e-6f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                // src1 is same shape as src0 => same indices
+                const int64_t i11 = i01;
+                const int64_t i12 = i02;
+                const int64_t i13 = i03;
+
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+
+                ggml_float sum_xx  = 0.0;
+                ggml_float sum_xdz = 0.0;
+
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
+                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
+                }
+
+                //const float mean     = (float)(sum_xx)/ne00;
+                const float mean_eps = (float)(sum_xx)/ne00 + eps;
+                const float sum_eps  = (float)(sum_xx) + eps*ne00;
+                //const float mean_xdz = (float)(sum_xdz)/ne00;
+                // we could cache rms from forward pass to improve performance.
+                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
+                //const float rms      = sqrtf(mean_eps);
+                const float rrms     = 1.0f / sqrtf(mean_eps);
+                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
+
+                {
+                    // z = rms_norm(x)
+                    //
+                    // rms_norm(src0) =
+                    //     scale(
+                    //         src0,
+                    //         div(
+                    //             1,
+                    //             sqrt(
+                    //                 add(
+                    //                     scale(
+                    //                         sum(
+                    //                             sqr(
+                    //                                 src0)),
+                    //                         (1.0/N)),
+                    //                     eps))));
+
+                    // postorder:
+                    // ## op    args         grad
+                    // 00 param src0         grad[#00]
+                    // 01 const 1
+                    // 02 sqr   (#00)        grad[#02]
+                    // 03 sum   (#02)        grad[#03]
+                    // 04 const 1/N
+                    // 05 scale (#03, #04)   grad[#05]
+                    // 06 const eps
+                    // 07 add   (#05, #06)   grad[#07]
+                    // 08 sqrt  (#07)        grad[#08]
+                    // 09 div   (#01,#08)    grad[#09]
+                    // 10 scale (#00,#09)    grad[#10]
+                    //
+                    // backward pass, given grad[#10]
+                    // #10: scale
+                    // grad[#00] += scale(grad[#10],#09)
+                    // grad[#09] += sum(mul(grad[#10],#00))
+                    // #09: div
+                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
+                    // #08: sqrt
+                    // grad[#07] += mul(grad[#08], div(0.5, #08))
+                    // #07: add
+                    // grad[#05] += grad[#07]
+                    // #05: scale
+                    // grad[#03] += scale(grad[#05],#04)
+                    // #03: sum
+                    // grad[#02] += repeat(grad[#03], #02)
+                    // #02:
+                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
+                    //
+                    // substitute and simplify:
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#02] = repeat(grad[#03], #02)
+                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
+                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
+                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
+                    // a = b*c + d*e
+                    // a = b*c*f/f + d*e*f/f
+                    // a = (b*c*f + d*e*f)*(1/f)
+                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
+                    // a = (b + d*e/c)*c
+                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
+                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
+                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
+                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
+                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
+                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
+                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                }
+                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                // post-order:
+                // dx := x
+                // dx := scale(dx,-mean_xdz/mean_eps)
+                // dx := add(dx, dz)
+                // dx := scale(dx, rrms)
+                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_vec_cpy_f32  (ne00, dx, x);
+                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
+                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
+                ggml_vec_acc_f32  (ne00, dx, dz);
+                ggml_vec_scale_f32(ne00, dx, rrms);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
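The final line of the long comment derivation can be stated compactly. A transcription into conventional notation (my rendering, not from the source):

```latex
% Closing result of the derivation above: for one row x, dz \in \mathbb{R}^N,
\[
  \mu_{\epsilon} = \frac{1}{N}\sum_{j=1}^{N} x_j^{2} + \epsilon,
  \qquad
  r = \frac{1}{\sqrt{\mu_{\epsilon}}},
\]
\[
  dx_i = r \left( dz_i - \frac{x_i}{N\,\mu_{\epsilon}}
          \sum_{j=1}^{N} x_j\, dz_j \right),
\]
% which is exactly dx = scale(dz + scale(x, -mean_xdz/mean_eps), rrms);
% the code divides sum_xdz by sum_eps = N*mean_eps, the same quantity.
```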
+
+
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
@@ -7722,8 +9719,17 @@ static void ggml_compute_forward_scale_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb1 = dst->nb[1];
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v);
+        if (dst->data != src0->data) {
+            // src0 is same shape as dst => same indices
+            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+        }
+        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
     }
 }
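The scale kernel above is no longer implicitly in-place: when `dst` and `src0` are distinct buffers, each row is copied first and then scaled. A condensed sketch of that copy-then-scale step:

```c
/* Out-of-place scale as introduced above: copy the row into dst if
   the buffers differ, then scale dst in place. */
#include <stdio.h>
#include <string.h>

static void scale_row(float *dst, const float *src, int nc, float v) {
    if (dst != src) {
        memcpy(dst, src, nc * sizeof(float)); // bring the row into dst
    }
    for (int i = 0; i < nc; i++) {
        dst[i] *= v;                          // then scale in place
    }
}

int main(void) {
    float src[3] = { 1, 2, 3 }, dst[3];
    scale_row(dst, src, 3, 2.0f);
    printf("%g %g %g\n", dst[0], dst[1], dst[2]); // 2 4 6
    return 0;
}
```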
 
@@ -7744,6 +9750,115 @@ static void ggml_compute_forward_scale(
     }
 }
 
+// ggml_compute_forward_set
+
+static void ggml_compute_forward_set_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(opt0) == 5);
+
+    // view src0 and dst with these strides and data offset in bytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) opt0->data)[0];
+    size_t nb2     = ((int32_t *) opt0->data)[1];
+    size_t nb3     = ((int32_t *) opt0->data)[2];
+    size_t offset  = ((int32_t *) opt0->data)[3];
+    bool   inplace = (bool) ((int32_t *) opt0->data)[4];
+
+    if (!inplace && (params->type == GGML_TASK_INIT)) {
+        // memcpy needs to be synchronized across threads to avoid race conditions.
+        // => do it in INIT phase
+        memcpy(
+            ((char *)  dst->data),
+            ((char *) src0->data),
+            ggml_nbytes(dst));
+    }
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb12 = src1->nb[2];
+    const size_t nb13 = src1->nb[3];
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  < ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
+static void ggml_compute_forward_set(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_f32(params, src0, src1, opt0, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
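`ggml_compute_forward_set_f32` reuses the same `opt0`-encoded view as acc above but overwrites rows (`ggml_vec_cpy_f32`) instead of accumulating (`ggml_vec_add_f32`). A reduced side-by-side of the two behaviours:

```c
/* set vs acc on the same destination row: overwrite vs accumulate.
   Plain loops stand in for ggml_vec_cpy_f32 / ggml_vec_add_f32. */
#include <stdio.h>

static void vec_cpy(int n, float *y, const float *x) {
    for (int i = 0; i < n; i++) y[i] = x[i];        // set: overwrite
}

static void vec_add(int n, float *y, const float *a, const float *b) {
    for (int i = 0; i < n; i++) y[i] = a[i] + b[i]; // acc: accumulate
}

int main(void) {
    float dst[3] = { 10, 10, 10 };
    const float src1[3] = { 1, 2, 3 };

    vec_add(3, dst, dst, src1);  // acc-style: { 11, 12, 13 }
    vec_cpy(3, dst, src1);       // set-style: {  1,  2,  3 }

    printf("%g %g %g\n", dst[0], dst[1], dst[2]);
    return 0;
}
```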
+
 // ggml_compute_forward_cpy
 
 static void ggml_compute_forward_cpy(
@@ -7939,22 +10054,220 @@ static void ggml_compute_forward_get_rows(
     //}
 }
 
-// ggml_compute_forward_diag_mask_inf
+// ggml_compute_forward_get_rows_back
 
-static void ggml_compute_forward_diag_mask_inf_f32(
+static void ggml_compute_forward_get_rows_back_f32_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 1);
+        const struct ggml_tensor * opt0,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    ggml_compute_forward_dup_same_cont(params, opt0, dst);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        for (int j = 0; j < nc; ++j) {
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
+            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rows_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
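+    // adjoint of get_rows: row i of src0 carries the gradient for dst row
+    // r = src1[i]; repeated indices in src1 accumulate into the same row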
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *) src0->data + i*src0->nb[1]));
+    }
+}
+
+static void ggml_compute_forward_get_rows_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_diag
+
+static void ggml_compute_forward_diag_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+    GGML_ASSERT(ne00 == ne0);
+    GGML_ASSERT(ne00 == ne1);
+    GGML_ASSERT(ne01 == 1);
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne3);
+
+    const int nb00 = src0->nb[0];
+    //const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb0  == sizeof(float));
+
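+    // e.g. a src0 row [a, b, c] (ne0 = 3, ne1 = 1) produces
+    //   [a 0 0]
+    //   [0 b 0]
+    //   [0 0 c]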
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = 0; i2 < ne2; i2++) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
+                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
+                for (int i0 = 0; i0 < i1; i0++) {
+                    d[i0] = 0;
+                }
+                d[i1] = s[i1];
+                for (int i0 = i1+1; i0 < ne0; i0++) {
+                    d[i0] = 0;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_diag(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+static void ggml_compute_forward_diag_mask_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const float value) {
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int  n_past  =       ((int32_t *) src1->data)[0];
+    const bool inplace = (bool)((int32_t *) src1->data)[1];
+    assert(n_past >= 0);
+
+    if (!inplace && (params->type == GGML_TASK_INIT)) {
+        // memcpy needs to be synchronized across threads to avoid race conditions.
+        // => do it in INIT phase
+        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+        memcpy(
+            ((char *)  dst->data),
+            ((char *) src0->data),
+            ggml_nbytes(dst));
+    }
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
 
     // TODO: handle transposed/permuted matrices
 
@@ -7967,10 +10280,10 @@ static void ggml_compute_forward_diag_mask_inf_f32(
     assert(src0->nb[0] == sizeof(float));
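+    // entries with column index i > n_past + j ("future" positions relative
+    // to row j) are overwritten with `value`: -INFINITY produces the causal
+    // mask applied before soft_max, 0 produces diag_mask_zero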
 
     for (int k = 0; k < nz; k++) {
-        for (int j = 0; j < nr; j++) {
+        for (int j = ith; j < nr; j += nth) {
             for (int i = n_past; i < nc; i++) {
                 if (i > n_past + j) {
-                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY;
+                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
                 }
             }
         }
@@ -7985,7 +10298,24 @@ static void ggml_compute_forward_diag_mask_inf(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
+                ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, -INFINITY);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+static void ggml_compute_forward_diag_mask_zero(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, 0);
             } break;
         default:
             {
@@ -8024,44 +10354,44 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *p = (float *)((char *) dst->data + i1*dst->nb[1]);
+        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(p[i]));
+            assert(!isnan(sp[i]));
         }
 #endif
 
         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, p);
+        ggml_vec_max_f32(nc, &max, sp);
 
         ggml_float sum = 0.0;
 
         uint16_t scvt;
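+        // exp is evaluated through table_exp_f16, a 1 << 16 entry table of
+        // precomputed values indexed by the fp16 bit pattern of sp[i] - max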
         for (int i = 0; i < nc; i++) {
-            //printf("p[%3d] = %8.4f\n", i, p[i]);
-            if (p[i] == -INFINITY) {
-                p[i] = 0.0f;
+            if (sp[i] == -INFINITY) {
+                dp[i] = 0.0f;
             } else {
-                //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
+                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
                 sum += (ggml_float)val;
-                p[i] = val;
+                dp[i] = val;
             }
         }
 
         assert(sum > 0.0);
 
         sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, p, sum);
+        ggml_vec_scale_f32(nc, dp, sum);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
-            assert(!isnan(p[i]));
-            assert(!isinf(p[i]));
+            assert(!isnan(dp[i]));
+            assert(!isinf(dp[i]));
         }
 #endif
     }
@@ -8101,6 +10431,8 @@ static void ggml_compute_forward_alibi_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_head = ((int32_t *) src1->data)[1];
 
+    assert(n_past >= 0);
+
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     //const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -8139,7 +10471,7 @@ static void ggml_compute_forward_alibi_f32(
                     m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
                 }
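+
+                // the ALiBi bias grows linearly with the position index i
+                // (the old code used the row index j+1)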
 
-                pdst[0] = (j+1) * m_k + src[0];
+                pdst[0] = i * m_k + src[0];
             }
         }
     }
@@ -8162,6 +10494,8 @@ static void ggml_compute_forward_alibi_f16(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_head = ((int32_t *) src1->data)[1];
 
+    assert(n_past >= 0);
+
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     //const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -8201,7 +10535,7 @@ static void ggml_compute_forward_alibi_f16(
                 }
 
                 // we return F32
-                pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+                pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
             }
         }
     }
@@ -8245,8 +10579,8 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -8256,25 +10590,35 @@ static void ggml_compute_forward_rope_f32(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    assert(n_past >= 0);
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 
-    assert(nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -8292,37 +10636,50 @@ static void ggml_compute_forward_rope_f32(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
 
                         dst_data[0] = x0*cos_theta - x1*sin_theta;
                         dst_data[1] = x0*sin_theta + x1*cos_theta;
-                    } else {
-                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims/2];
+                            theta *= theta_scale;
 
-                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims/2];
+
+                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        }
                     }
                 }
             }
@@ -8335,8 +10692,8 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -8346,25 +10703,35 @@ static void ggml_compute_forward_rope_f16(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    assert(n_past >= 0);
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 
-    assert(nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -8382,37 +10749,50 @@ static void ggml_compute_forward_rope_f16(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = GGML_FP16_TO_FP32(src[0]);
                         const float x1 = GGML_FP16_TO_FP32(src[1]);
 
                         dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    } else {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                            theta *= theta_scale;
 
-                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
                     }
                 }
             }
@@ -8441,6 +10821,255 @@ static void ggml_compute_forward_rope(
     }
 }
 
+// ggml_compute_forward_rope_back
+
+static void ggml_compute_forward_rope_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // y = rope(x, src1)
+    // dx = rope_back(dy, src1)
+    // src0 is dy, src1 contains options
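+    //
+    // rope rotates each (x0, x1) pair by theta; the rotation is orthogonal,
+    // so the backward pass applies the inverse rotation (-theta):
+    //   dx0 =  cos(theta)*dy0 + sin(theta)*dy1
+    //   dx1 = -sin(theta)*dy0 + cos(theta)*dy1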
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    assert(n_past >= 0);
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    assert(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
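+    // theta for the pair at i0 is p * theta_scale^(i0/2) = p * 10000^(-i0/n_dims)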
+
+    const bool is_neox = mode & 2;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                float theta = (float)p;
+
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+
+                        theta *= theta_scale;
+
+                        const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float dy0 = dy[0];
+                        const float dy1 = dy[1];
+
+                        dx[0] =   dy0*cos_theta + dy1*sin_theta;
+                        dx[1] = - dy0*sin_theta + dy1*cos_theta;
+                    }
+                } else {
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
+
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float dy0 = dy[0];
+                            const float dy1 = dy[n_dims/2];
+
+                            dx[0]        =   dy0*cos_theta + dy1*sin_theta;
+                            dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rope_back_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // y = rope(x, src1)
+    // dx = rope_back(dy, src1)
+    // src0 is dy, src1 contains options
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    assert(n_past >= 0);
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    assert(nb0 == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
+    const bool is_neox = mode & 2;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                float theta = (float)p;
+
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+
+                        theta *= theta_scale;
+
+                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float dy0 = GGML_FP16_TO_FP32(dy[0]);
+                        const float dy1 = GGML_FP16_TO_FP32(dy[1]);
+
+                        dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
+                        dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
+                    }
+                } else {
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
+
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float dy0 = GGML_FP16_TO_FP32(dy[0]);
+                            const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
+
+                            dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
+                            dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rope_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_conv_1d_1s
 
 static void ggml_compute_forward_conv_1d_1s_f16_f32(
@@ -9760,6 +12389,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_ADD1:
+            {
+                ggml_compute_forward_add1(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_ACC:
+            {
+                ggml_compute_forward_acc(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            } break;
         case GGML_OP_SUB:
             {
                 ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
@@ -9780,10 +12417,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_sqrt(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_LOG:
+            {
+                ggml_compute_forward_log(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SUM:
             {
                 ggml_compute_forward_sum(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SUM_ROWS:
+            {
+                ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MEAN:
             {
                 ggml_compute_forward_mean(params, tensor->src0, tensor);
@@ -9820,6 +12465,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SILU_BACK:
+            {
+                ggml_compute_forward_silu_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_NORM:
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
@@ -9828,6 +12477,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_RMS_NORM_BACK:
+            {
+                ggml_compute_forward_rms_norm_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -9836,6 +12489,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_SET:
+            {
+                ggml_compute_forward_set(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_compute_forward_cpy(params, tensor->src0, tensor);
@@ -9864,10 +12521,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            } break;
+        case GGML_OP_DIAG:
+            {
+                ggml_compute_forward_diag(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_DIAG_MASK_INF:
             {
                 ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_DIAG_MASK_ZERO:
+            {
+                ggml_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SOFT_MAX:
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
@@ -9876,6 +12545,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_ROPE_BACK:
+            {
+                ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ALIBI:
             {
                 ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
@@ -9944,6 +12617,48 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
                 }
             } break;
+        case GGML_OP_ADD1:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+                if (src1->grad) {
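+                    // src1 is a scalar broadcast to every element of src0,
+                    // so its gradient is the sum over tensor->grad (hence the
+                    // TODO below: mean under-scales by 1/nelements)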
+                    src1->grad = ggml_add_impl(ctx,
+                        src1->grad,
+                        ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
+                        inplace);
+                }
+            } break;
+        case GGML_OP_ACC:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+                if (src1->grad) {
+                    GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
+                    GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
+                    const size_t nb1     = (( int32_t * ) tensor->opt[0]->data)[0];
+                    const size_t nb2     = (( int32_t * ) tensor->opt[0]->data)[1];
+                    const size_t nb3     = (( int32_t * ) tensor->opt[0]->data)[2];
+                    const size_t offset  = (( int32_t * ) tensor->opt[0]->data)[3];
+
+                    struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
+                        tensor->grad,
+                        src1->grad->ne[0],
+                        src1->grad->ne[1],
+                        src1->grad->ne[2],
+                        src1->grad->ne[3],
+                        nb1, nb2, nb3, offset);
+
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                            src1->grad,
+                            ggml_reshape(ctx,
+                                ggml_cont(ctx, tensor_grad_view),
+                                src1->grad),
+                            inplace);
+                }
+            } break;
         case GGML_OP_SUB:
             {
                 if (src0->grad) {
@@ -9995,31 +12710,57 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                ggml_mul(ctx,
+                                ggml_scale(ctx,
                                     ggml_mul(ctx, src0, tensor->grad),
-                                    ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)),
+                                    ggml_new_f32(ctx, 2.0f)),
                                 inplace);
                 }
             } break;
         case GGML_OP_SQRT:
+            {
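+                // y = sqrt(x)  =>  dy/dx = 0.5/sqrt(x) = 0.5/y,  so dx = dy * 0.5/y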
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad, // this was not caught by test_grad because in test_grad tensor->grad is 1
+                                    ggml_div(ctx,
+                                        ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
+                                        tensor)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_LOG:
             {
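+                // y = log(x)  =>  dx = dy / x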
                 if (src0->grad) {
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
                                 ggml_div(ctx,
-                                    ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
-                                    tensor),
+                                    tensor->grad,
+                                    src0),
                                 inplace);
                 }
             } break;
         case GGML_OP_SUM:
+            {
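+                // y = sum(x) is a scalar, so d/dx y = 1 for every element;
+                // add1 broadcasts tensor->grad over src0->grad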
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add1_impl(ctx,
+                                src0->grad,
+                                tensor->grad,
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SUM_ROWS:
             {
                 if (src0->grad) {
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                ggml_repeat(ctx, tensor->grad, src0->grad),
+                                ggml_repeat(ctx,
+                                    tensor->grad,
+                                    src0->grad),
                                 inplace);
                 }
             } break;
@@ -10029,11 +12770,44 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_REPEAT:
             {
+                // necessary for llama
                 if (src0->grad) {
+                    GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
+                    const int nc  = tensor->ne[0];
+                    const int nr  = tensor->ne[1];
+                    const int nc0 = src0->ne[0];
+                    const int nr0 = src0->ne[1];
+                    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
+                    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
+                    // tensor->grad [nc,nr,1,1]
+                    // reshape      [nc0,nc/nc0,nr0,nr/nr0]
+                    // permute      [nc0,nr0,nc/nc0,nr/nr0]
+                    // substitute   [nc0,nr0,ncr,nrr]
+                    // reshape      [nc0*nr0,ncr*nrr,1,1]
+                    // transpose    [ncr*nrr,nc0*nr0,1,1]
+                    // sum rows     [1,nc0*nr0,1,1]
+                    // transpose    [nc0*nr0,1,1]
+                    // reshape      [nc0,nr0,1,1] reshape_1d or reshape_2d
+                    // add to src0->grad
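+                    //
+                    // each src0 element was copied ncr*nrr times, so its
+                    // gradient is the sum over all of its copies; the
+                    // reshape/permute chain groups the copies so a single
+                    // sum_rows accumulates them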
+
+                    int64_t ne[4]  = {nc0,ncr,nr0,nrr};
+
+                    struct ggml_tensor* F00 = tensor->grad;
+                    struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
+                    struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
+                    struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
+                    struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
+                    struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
+                    struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
+                    struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
+                    struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
+                    struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
+                    struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
+
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                ggml_sum(ctx, tensor->grad),
+                                F10,
                                 inplace);
                 }
             } break;
@@ -10087,6 +12861,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_SILU:
+            {
+                // necessary for llama
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_silu_back(ctx, src0, tensor->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_SILU_BACK:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -10095,68 +12879,372 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_RMS_NORM:
+            {
+                // necessary for llama
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_rms_norm_back(ctx, src0, tensor->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_RMS_NORM_BACK:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_MUL_MAT:
             {
+                // https://cs231n.github.io/optimization-2/#staged
+                // # forward pass
+                // s0 = np.random.randn(5, 10)
+                // s1 = np.random.randn(10, 3)
+                // t = s0.dot(s1)
+
+                // # now suppose we had the gradient on t from above in the circuit
+                // dt = np.random.randn(*t.shape) # same shape as t
+                // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
+                // ds1 = t.T.dot(dt)
+
+                // tensor.shape [m,p]
+                // src0.shape   [n,m]
+                // src1.shape   [n,p]
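+
+                // note: ggml_mul_mat(a, b) contracts over ne[0] of both
+                // arguments and returns ne = [a->ne[1], b->ne[1]]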
+
+                // necessary for llama
                 if (src0->grad) {
                     // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
-                    GGML_ASSERT(false);
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                // ds0 = dt.dot(s1.T)
+                                // ggml_out_prod(ctx, // [n,m]
+                                //     src1,          // [n,p]
+                                //     tensor->grad), // [m,p]
+                                // for now just using A*B==(B.T*A.T).T
+                                ggml_cont(ctx,                      // [n,m]
+                                    ggml_transpose(ctx,             // [n,m]
+                                        ggml_mul_mat(ctx,           // [m,n]
+                                            ggml_cont(ctx,          // [p,m]
+                                                ggml_transpose(ctx, // [p,m]
+                                                    tensor->grad)), // [m,p]
+                                            ggml_cont(ctx,          // [p,n]
+                                                ggml_transpose(ctx, // [p,n]
+                                                    src1))))),      // [n,p]
+                                inplace);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                ggml_mul_mat(ctx,
-                                    ggml_cont(ctx, ggml_transpose(ctx, src0)),
-                                    tensor->grad),
+                                // ds1 = s0.T.dot(dt):
+                                ggml_mul_mat(ctx,                   // [n,p]
+                                    ggml_cont(ctx,                  // [m,n]
+                                        ggml_transpose(ctx, src0)), // [m,n]
+                                    tensor->grad),                  // [m,p]
                                 inplace);
                 }
             } break;
         case GGML_OP_SCALE:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
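+                // y = v*x with scalar v = src1:  dx = v*dy,  dv = sum(x .* dy)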
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_scale_impl(ctx, tensor->grad, src1, false),
+                            inplace);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                            src1->grad,
+                            ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_SET:
+            {
+                GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
+                GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
+                const size_t nb1     = (( int32_t * ) tensor->opt[0]->data)[0];
+                const size_t nb2     = (( int32_t * ) tensor->opt[0]->data)[1];
+                const size_t nb3     = (( int32_t * ) tensor->opt[0]->data)[2];
+                const size_t offset  = (( int32_t * ) tensor->opt[0]->data)[3];
+
+                struct ggml_tensor * tensor_grad_view = NULL;
+
+                if (src0->grad || src1->grad) {
+                    GGML_ASSERT(src0->type == tensor->type);
+                    GGML_ASSERT(tensor->grad->type == tensor->type);
+                    GGML_ASSERT(tensor->grad->type == src1->grad->type);
+
+                    tensor_grad_view = ggml_view_4d(ctx,
+                        tensor->grad,
+                        src1->grad->ne[0],
+                        src1->grad->ne[1],
+                        src1->grad->ne[2],
+                        src1->grad->ne[3],
+                        nb1, nb2, nb3, offset);
+                }
+
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                        src0->grad,
+                        ggml_acc_impl(ctx,
+                            tensor->grad,
+                            ggml_neg(ctx, tensor_grad_view),
+                            nb1, nb2, nb3, offset, false),
+                        inplace);
+                }
+
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                            src1->grad,
+                            ggml_reshape(ctx,
+                                ggml_cont(ctx, tensor_grad_view),
+                                src1->grad),
+                            inplace);
+                }
             } break;
         case GGML_OP_CPY:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
+                // cpy overwrites value of src1 by src0 and returns view(src1)
+                // the overwriting is mathematically equivalent to:
+                // tensor = src0 * 1 + src1 * 0
+                if (src0->grad) {
+                    // dsrc0 = dtensor * 1
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+                if (src1->grad) {
+                    // dsrc1 = dtensor * 0 -> noop
+                }
             } break;
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // same as cpy
+                if (src0->grad) {
+                    GGML_ASSERT(ggml_is_contiguous(src0->grad));
+                    GGML_ASSERT(ggml_is_contiguous(tensor->grad));
+                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
             } break;
         case GGML_OP_RESHAPE:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_reshape(ctx, tensor->grad, src0->grad),
+                        inplace);
+                }
             } break;
         case GGML_OP_VIEW:
             {
-                GGML_ASSERT(false); // not supported
+                // necessary for llama
+                if (src0->grad) {
+                    size_t offset;
+                    memcpy(&offset, tensor->padding, sizeof(offset));
+
+                    size_t nb1     = tensor->nb[1];
+                    size_t nb2     = tensor->nb[2];
+                    size_t nb3     = tensor->nb[3];
+
+                    if (src0->type != src0->grad->type) {
+                        // gradient is typically F32, but src0 could be of a different type
+                        size_t ng = ggml_element_size(src0->grad);
+                        size_t n0 = ggml_element_size(src0);
+                        GGML_ASSERT(offset % n0 == 0);
+                        GGML_ASSERT(nb1 % n0 == 0);
+                        GGML_ASSERT(nb2 % n0 == 0);
+                        GGML_ASSERT(nb3 % n0 == 0);
+                        offset = (offset / n0) * ng;
+                        nb1 = (nb1 / n0) * ng;
+                        nb2 = (nb2 / n0) * ng;
+                        nb3 = (nb3 / n0) * ng;
+                    }
+
+                    src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
+                }
             } break;
         case GGML_OP_PERMUTE:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
+                if (src0->grad) {
+                    int axis0 = tensor->padding[0] & 0x3;
+                    int axis1 = tensor->padding[1] & 0x3;
+                    int axis2 = tensor->padding[2] & 0x3;
+                    int axis3 = tensor->padding[3] & 0x3;
+                    int axes_backward[4] = {0,0,0,0};
+                    axes_backward[axis0] = 0;
+                    axes_backward[axis1] = 1;
+                    axes_backward[axis2] = 2;
+                    axes_backward[axis3] = 3;
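+                    // axes_backward is the inverse permutation, e.g. a
+                    // forward permute of (1,2,0,3) yields (2,0,1,3) here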
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_permute(ctx,
+                                tensor->grad,
+                                axes_backward[0],
+                                axes_backward[1],
+                                axes_backward[2],
+                                axes_backward[3]),
+                            inplace);
+                }
             } break;
         case GGML_OP_TRANSPOSE:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_transpose(ctx, tensor->grad),
+                        inplace);
+                }
             } break;
         case GGML_OP_GET_ROWS:
+            {
+                // necessary for llama (only for tokenizer)
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
+                        inplace);
+                }
+                if (src1->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_DIAG:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_DIAG_MASK_INF:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
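+                // masked entries were overwritten with a constant, so they
+                // contribute zero gradient; unmasked entries pass dy through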
+                if (src0->grad) {
+                    assert(src1->type == GGML_TYPE_I32);
+                    assert(ggml_nelements(src1) == 2);
+                    const int n_past = ((int32_t *) src1->data)[0];
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
+                        inplace);
+                }
+                if (src1->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_DIAG_MASK_ZERO:
+            {
+                // necessary for llama
+                if (src0->grad) {
+                    assert(src1->type == GGML_TYPE_I32);
+                    assert(ggml_nelements(src1) == 2);
+                    const int n_past = ((int32_t *) src1->data)[0];
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
+                        inplace);
+                }
+                if (src1->grad) {
+                    // noop
+                }
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
+                if (src0->grad) {
+                    // y = softmax(x)
+                    //
+                    // Jii = yi - yi*yi
+                    // Jij = -yi*yj
+                    // J = diag(y) - y*y^T
+                    // dx = J * dy
+                    // dxk = sum_j(Jkj * dyj)
+
+                    int64_t ne2[4] = {
+                        tensor->ne[0],
+                        1,
+                        tensor->ne[1]*tensor->ne[2],
+                        tensor->ne[3]
+                    };
+                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
+                        ggml_reshape_4d(ctx,
+                            ggml_cont(ctx, tensor),
+                            ne2[0], ne2[1], ne2[2], ne2[3]));
+
+                    struct ggml_tensor * grad2 = ggml_cont(ctx,
+                        ggml_reshape_4d(ctx,
+                            ggml_cont(ctx, tensor->grad),
+                            ne2[0], ne2[1], ne2[2], ne2[3]));
+
+                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
+                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
+                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
+                            1, 0, 2, 3));
+
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                            src0->grad,                   // [ne0,ne1,ne2,ne3]
+                            ggml_reshape(ctx,             // [ne0,ne1,ne2,ne3]
+                                ggml_mul_mat(ctx,         // [ne0,1,ne1*ne2,ne3]
+                                    ggml_sub(ctx,         // [ne0,ne0,ne1*ne2,ne3]
+                                        ggml_diag(ctx,    // [ne0,ne0,ne1*ne2,ne3]
+                                            tensor2),     // [ne0,1,ne1*ne2,ne3]
+                                        ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
+                                            tensor2_t,    // [1,ne0,ne1*ne2,ne3]
+                                            tensor2_t)),  // [1,ne0,ne1*ne2,ne3]
+                                    grad2),               // [ne0,1,ne1*ne2,ne3]
+                                src0->grad),
+                            inplace);
+                }
             } break;
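
The graph built above materializes the Jacobian product
dx = (diag(y) - y*y^T) * dy, batched over ne1*ne2*ne3 rows. Elementwise this
collapses to dx_k = y_k * (dy_k - sum_j y_j*dy_j), which a kernel could
compute without forming the [ne0,ne0] matrix. A standalone numeric check of
the identity (illustration only, not part of the patch):

#include <stdio.h>

int main(void) {
    const float y[3]  = {0.2f, 0.3f, 0.5f};  // softmax output, sums to 1
    const float dy[3] = {1.0f, 0.0f, 0.0f};  // upstream gradient
    float dot = 0.0f;
    for (int j = 0; j < 3; j++) {
        dot += y[j] * dy[j];
    }
    for (int k = 0; k < 3; k++) {
        printf("dx[%d] = %f\n", k, y[k] * (dy[k] - dot));
    }
    return 0;
}
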
         case GGML_OP_ROPE:
             {
-                GGML_ASSERT(false); // TODO: not implemented
+                // necessary for llama
+                if (src0->grad) {
+                    assert(src1->type == GGML_TYPE_I32);
+                    assert(ggml_nelements(src1) == 3);
+                    const int n_past = ((int32_t *) src1->data)[0];
+                    const int n_dims = ((int32_t *) src1->data)[1];
+                    const int mode   = ((int32_t *) src1->data)[2];
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_rope_back(ctx,
+                                tensor->grad,
+                                n_past,
+                                n_dims,
+                                mode),
+                            inplace);
+                }
+                if (src1->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_ROPE_BACK:
+            {
+                if (src0->grad) {
+                    assert(src1->type == GGML_TYPE_I32);
+                    assert(ggml_nelements(src1) == 3);
+                    const int n_past = ((int32_t *) src1->data)[0];
+                    const int n_dims = ((int32_t *) src1->data)[1];
+                    const int mode   = ((int32_t *) src1->data)[2];
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_rope(ctx,
+                                tensor->grad,
+                                n_past,
+                                n_dims,
+                                mode),
+                            inplace);
+                }
+                if (src1->grad) {
+                    // noop
+                }
             } break;
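
ggml_rope and ggml_rope_back can serve as each other's backward because RoPE
rotates element pairs by position-dependent angles: rotating by -theta undoes
rotating by +theta, and a rotation's transpose is its inverse, so each op is
the other's adjoint. A standalone sketch of the underlying 2D rotation
(illustration only, not part of the patch):

#include <math.h>
#include <stdio.h>

static void rotate(float *x0, float *x1, float theta) {
    const float c = cosf(theta), s = sinf(theta);
    const float r0 = *x0 * c - *x1 * s;
    const float r1 = *x0 * s + *x1 * c;
    *x0 = r0;
    *x1 = r1;
}

int main(void) {
    float a = 1.0f, b = 2.0f;
    rotate(&a, &b, 0.5f);     // forward rotation
    rotate(&a, &b, -0.5f);    // inverse rotation recovers the input
    printf("%f %f\n", a, b);  // ~1.000000 2.000000
    return 0;
}
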
         case GGML_OP_CONV_1D_1S:
             {
@@ -10516,6 +13604,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_ADD:
+                case GGML_OP_ADD1:
                     {
                         node->n_tasks = n_threads;
 
@@ -10525,14 +13614,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
                         }
 
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_ACC:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        if (ggml_is_quantized(node->src0->type)) {
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
+                        }
+
                         work_size = MAX(work_size, cur);
                     } break;
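
The GGML_OP_ACC sizing above reserves one dequantized f32 row of src1 per
thread when src0 is quantized. A standalone sketch of the arithmetic, with
hypothetical sizes (illustration only, not part of the patch):

#include <stdio.h>

int main(void) {
    const size_t f32_size  = sizeof(float);
    const size_t ne0       = 4096;   // hypothetical length of src1's rows
    const size_t n_threads = 8;
    const size_t cur = f32_size * ne0 * n_threads;
    printf("work buffer: %zu bytes\n", cur);   // one scratch row per thread
    return 0;
}
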
                 case GGML_OP_SUB:
-                case GGML_OP_MUL:
                 case GGML_OP_DIV:
                 case GGML_OP_SQR:
                 case GGML_OP_SQRT:
+                case GGML_OP_LOG:
                 case GGML_OP_SUM:
+                case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
                 case GGML_OP_REPEAT:
                 case GGML_OP_ABS:
@@ -10543,16 +13645,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = 1;
                     } break;
+                case GGML_OP_MUL:
                 case GGML_OP_GELU:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
                 case GGML_OP_SILU:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
+                case GGML_OP_SILU_BACK:
                 case GGML_OP_NORM:
                 case GGML_OP_RMS_NORM:
+                case GGML_OP_RMS_NORM_BACK:
                     {
                         node->n_tasks = n_threads;
                     } break;
@@ -10618,21 +13717,23 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
+                case GGML_OP_SET:
                 case GGML_OP_CONT:
                 case GGML_OP_RESHAPE:
                 case GGML_OP_VIEW:
                 case GGML_OP_PERMUTE:
                 case GGML_OP_TRANSPOSE:
                 case GGML_OP_GET_ROWS:
-                case GGML_OP_DIAG_MASK_INF:
+                case GGML_OP_GET_ROWS_BACK:
+                case GGML_OP_DIAG:
+                case GGML_OP_DIAG_MASK_ZERO:
                     {
                         node->n_tasks = 1;
                     } break;
+                case GGML_OP_DIAG_MASK_INF:
                 case GGML_OP_SOFT_MAX:
-                    {
-                        node->n_tasks = n_threads;
-                    } break;
                 case GGML_OP_ROPE:
+                case GGML_OP_ROPE_BACK:
                     {
                         node->n_tasks = n_threads;
                     } break;
diff --git a/third_party/ggml/ggml.h b/third_party/ggml/ggml.h
index 5cf9e4de6..26e1cc399 100644
--- a/third_party/ggml/ggml.h
+++ b/third_party/ggml/ggml.h
@@ -226,6 +226,11 @@ COSMOPOLITAN_C_START_
         GGML_TYPE_COUNT,
     };
 
+    enum ggml_backend {
+        GGML_BACKEND_CPU = 0,
+        GGML_BACKEND_CUDA = 1,
+    };
+
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN     = -1,
@@ -246,12 +251,16 @@ COSMOPOLITAN_C_START_
 
         GGML_OP_DUP,
         GGML_OP_ADD,
+        GGML_OP_ADD1,
+        GGML_OP_ACC,
         GGML_OP_SUB,
         GGML_OP_MUL,
         GGML_OP_DIV,
         GGML_OP_SQR,
         GGML_OP_SQRT,
+        GGML_OP_LOG,
         GGML_OP_SUM,
+        GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_REPEAT,
         GGML_OP_ABS,
@@ -261,12 +270,15 @@ COSMOPOLITAN_C_START_
         GGML_OP_RELU,
         GGML_OP_GELU,
         GGML_OP_SILU,
+        GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
+        GGML_OP_RMS_NORM_BACK,
 
         GGML_OP_MUL_MAT,
 
         GGML_OP_SCALE,
+        GGML_OP_SET,
         GGML_OP_CPY,
         GGML_OP_CONT,
         GGML_OP_RESHAPE,
@@ -274,9 +286,13 @@ COSMOPOLITAN_C_START_
         GGML_OP_PERMUTE,
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
+        GGML_OP_GET_ROWS_BACK,
+        GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
+        GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
         GGML_OP_ROPE,
+        GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CONV_1D_1S,
         GGML_OP_CONV_1D_2S,
@@ -305,7 +321,8 @@ COSMOPOLITAN_C_START_
 
     // n-dimensional tensor
     struct ggml_tensor {
-        enum ggml_type type;
+        enum ggml_type    type;
+        enum ggml_backend backend;
 
         int     n_dims;
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -336,7 +353,7 @@ COSMOPOLITAN_C_START_
 
         char name[32];
 
-        char padding[8]; // TODO: remove and add padding to name?
+        char padding[16]; // TODO: remove and add padding to name?
     };
 
     // computation graph
@@ -487,6 +504,29 @@ COSMOPOLITAN_C_START_
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_add1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_acc(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_acc_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
     GGML_API struct ggml_tensor * ggml_sub(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -510,12 +550,24 @@ COSMOPOLITAN_C_START_
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_log(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_log_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return scalar
-    // TODO: compute sum along rows
     GGML_API struct ggml_tensor * ggml_sum(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // sums along rows: input shape [a,b,c,d] gives result shape [1,b,c,d]
+    GGML_API struct ggml_tensor * ggml_sum_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // mean along rows
     GGML_API struct ggml_tensor * ggml_mean(
             struct ggml_context * ctx,
@@ -557,6 +609,13 @@ COSMOPOLITAN_C_START_
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_silu_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
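
For reference, ggml_silu_back applies the pointwise chain rule for
silu(x) = x*sigmoid(x). The sketch below uses the standard SiLU derivative,
sigmoid(x) * (1 + x * (1 - sigmoid(x))); that this matches the kernel is an
assumption, not something the patch states (illustration only):

#include <math.h>
#include <stdio.h>

int main(void) {
    const float x = 0.7f, dy = 1.0f;
    const float s  = 1.0f / (1.0f + expf(-x));         // sigmoid(x)
    const float dx = dy * s * (1.0f + x * (1.0f - s)); // dy * silu'(x)
    printf("dx = %f\n", dx);
    return 0;
}
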
     // normalize along rows
     // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
@@ -567,6 +626,13 @@ COSMOPOLITAN_C_START_
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_rms_norm_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // A: m rows, n columns
     // B: p rows, n columns (i.e. we transpose it internally)
     // result is m columns, p rows
@@ -579,12 +645,66 @@ COSMOPOLITAN_C_START_
     // operations on tensors without backpropagation
     //
 
-    // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_scale(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_scale_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // b -> view(a,offset,nb1,nb2,nb3), return modified a
+    GGML_API struct ggml_tensor * ggml_set(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1,nb2,nb3), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_set_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_set_1d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1), return modified a
+    GGML_API struct ggml_tensor * ggml_set_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_2d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset);
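
The nb1/nb2/nb3 arguments throughout the ggml_acc and ggml_set family are byte
strides: element (i0,i1,i2,i3) of the destination view lives at byte
offset + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 of a's buffer. A standalone sketch
with hypothetical sizes (illustration only, not part of the patch):

#include <stdio.h>

int main(void) {
    const size_t nb0    = sizeof(float);  // element stride
    const size_t nb1    = 16 * nb0;       // row stride: 16 floats per row
    const size_t nb2    = 8 * nb1;        // slice stride: 8 rows per slice
    const size_t offset = 2 * nb1;        // view starts at row 2
    const size_t byte   = offset + 3*nb0 + 1*nb1 + 0*nb2;
    printf("element (3,1,0) of the view is at byte %zu\n", byte);
    return 0;
}
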
+
     // a -> b, return view(b)
     GGML_API struct ggml_tensor * ggml_cpy(
             struct ggml_context * ctx,
@@ -603,6 +723,13 @@ COSMOPOLITAN_C_START_
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
     // return view(a)
     // TODO: when we start computing gradient, make a copy instead of view
     GGML_API struct ggml_tensor * ggml_reshape_2d(
@@ -620,6 +747,14 @@ COSMOPOLITAN_C_START_
             int64_t               ne1,
             int64_t               ne2);
 
+    GGML_API struct ggml_tensor * ggml_reshape_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
     // offset in bytes
     GGML_API struct ggml_tensor * ggml_view_1d(
             struct ggml_context * ctx,
@@ -645,6 +780,18 @@ COSMOPOLITAN_C_START_
             size_t                nb2, // slice stride in bytes
             size_t                offset);
 
+    GGML_API struct ggml_tensor * ggml_view_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                nb3,
+            size_t                offset);
+
     GGML_API struct ggml_tensor * ggml_permute(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -663,18 +810,49 @@ COSMOPOLITAN_C_START_
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_get_rows_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c);
+
+    GGML_API struct ggml_tensor * ggml_diag(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // set elements above the diagonal to -INF
-    // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_diag_mask_inf(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past);
 
     // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // set elements above the diagonal to 0
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
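
ggml_diag_mask_inf and ggml_diag_mask_zero differ only in the fill value for
entries above the diagonal (shifted right by n_past columns): -INF when the
result feeds a softmax, 0 otherwise. A toy illustration of the masking rule
(not the ggml kernel; illustration only):

#include <math.h>
#include <stdio.h>

int main(void) {
    enum { N = 3 };                      // rows; columns are n_past + N
    const int n_past = 1;
    for (int r = 0; r < N; r++) {
        for (int c = 0; c < n_past + N; c++) {
            float v = (c > n_past + r) ? -INFINITY : 1.0f; // or 0.0f for _zero
            printf("%6.1f ", v);
        }
        printf("\n");
    }
    return 0;
}
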
     GGML_API struct ggml_tensor * ggml_soft_max(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // rotary position embedding
     // in-place, returns view(a)
     // if mode & 1 == 1, skip n_past elements
@@ -687,6 +865,23 @@ COSMOPOLITAN_C_START_
             int                   n_dims,
             int                   mode);
 
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode);
+
+    // rotary position embedding backward, i.e. compute dx from dy
+    // a - dy
+    GGML_API struct ggml_tensor * ggml_rope_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode);
+
     // alibi position embedding
     // in-place, returns view(a)
     struct ggml_tensor * ggml_alibi(
@@ -731,13 +926,13 @@ COSMOPOLITAN_C_START_
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context        * ctx,
             struct ggml_tensor         * a,
-            const  ggml_unary_op_f32_t fun);
+                   ggml_unary_op_f32_t   fun);
 
     GGML_API struct ggml_tensor * ggml_map_binary_f32(
             struct ggml_context         * ctx,
             struct ggml_tensor          * a,
             struct ggml_tensor          * b,
-            const  ggml_binary_op_f32_t fun);
+                   ggml_binary_op_f32_t   fun);
 
     //
     // automatic differentiation
diff --git a/third_party/ggml/llama.cc b/third_party/ggml/llama.cc
index 610f54b22..bfd2e8c9c 100644
--- a/third_party/ggml/llama.cc
+++ b/third_party/ggml/llama.cc
@@ -105,26 +105,26 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
 // 2*n_embd*n_ctx*n_layer*sizeof(float16)
 static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
 {
-    static std::map<e_model, size_t> _MEM_REQ_KV_SELF = {
+    static std::map<e_model, size_t> k_sizes = {
         { MODEL_7B,   1026ull * MB },
         { MODEL_13B,  1608ull * MB },
         { MODEL_30B,  3124ull * MB },
         { MODEL_65B,  5120ull * MB },
     };
-    return _MEM_REQ_KV_SELF;
+    return k_sizes;
 }
 
 // this is mostly needed for temporary mul_mat buffers to dequantize the data
 // not actually needed if BLAS is disabled
 static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 {
-    static std::map<e_model, size_t> _MEM_REQ_EVAL = {
+    static std::map<e_model, size_t> k_sizes = {
         { MODEL_7B,   768ull * MB },
         { MODEL_13B, 1024ull * MB },
         { MODEL_30B, 1280ull * MB },
         { MODEL_65B, 1536ull * MB },
     };
-    return _MEM_REQ_EVAL;
+    return k_sizes;
 }
 
 // default hparams (LLaMA 7B)
@@ -681,7 +681,7 @@ struct llama_model_loader {
         }
     }
 
-    struct ggml_tensor * get_tensor(const std::string & name, std::vector<uint32_t> ne) {
+    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne) {
         auto it = tensors_map.name_to_idx.find(name);
         if (it == tensors_map.name_to_idx.end()) {
             Die("llama.cpp: tensor '%s' is missing from model", name.c_str());
@@ -1131,7 +1131,7 @@ static bool llama_eval_internal(
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
 
-    auto & kv_self = model.kv_self;
+    const auto & kv_self = model.kv_self;
 
     LLAMA_ASSERT(!!kv_self.ctx);
 
@@ -1184,8 +1184,8 @@ static bool llama_eval_internal(
         // self-attention
         {
             // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");
 
@@ -1226,17 +1226,19 @@ static bool llama_eval_internal(
             struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
             ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
 
-            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
             // KQ_masked = mask_past(KQ_scaled)
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
             ggml_set_name(KQ_masked, "KQ_masked");
 
             // KQ = soft_max(KQ_masked)
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
@@ -1337,7 +1339,7 @@ static bool llama_eval_internal(
     lctx.use_buf(ctx0, -1);
 
     // logits -> probs
-    //inpL = ggml_soft_max(ctx0, inpL);
+    //inpL = ggml_soft_max_inplace(ctx0, inpL);
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
@@ -1375,7 +1377,7 @@ static bool llama_eval_internal(
     }
 
     // extract embeddings
-    if (lctx.embedding.size()) {
+    if (!lctx.embedding.empty()) {
         auto & embedding_out = lctx.embedding;
 
         embedding_out.resize(n_embd);
@@ -1426,6 +1428,8 @@ struct llama_sp_symbol {
     size_t n;
 };
 
+static_assert(std::is_trivially_copyable<llama_sp_symbol>::value, "llama_sp_symbol is not trivially copyable");
+
 struct llama_sp_bigram {
     struct comparator {
         bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
@@ -1458,7 +1462,7 @@ struct llama_tokenizer {
             sym.prev = index - 1;
             sym.next = offs == text.size() ? -1 : index + 1;
             index++;
-            symbols_.emplace_back(std::move(sym));
+            symbols_.emplace_back(sym);
         }
 
         // seed the work queue with all possible 2-character tokens.
@@ -1549,7 +1553,7 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
 
-    if (text.size() == 0) {
+    if (text.empty()) {
         return output;
     }
 
@@ -1785,7 +1789,7 @@ void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_dat
     const int64_t t_start_sample_us = ggml_time_us();
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        auto token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
+        const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
         if (token_iter == last_tokens + last_tokens_size) {
             continue;
         }
@@ -1929,7 +1933,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
     const int64_t t_start_sample_us = ggml_time_us();
 
     // Find max element
-    auto max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
         return a.logit < b.logit;
     });
 
@@ -2286,7 +2290,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model);
         model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false));
 
-        size_t ctx_size, mmapped_size;
+        size_t ctx_size;
+        size_t mmapped_size;
         model_loader->calc_sizes(&ctx_size, &mmapped_size);
         base_buf.resize(ctx_size);
 
@@ -2325,8 +2330,12 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
             fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
         }
 
-        std::string name(length, 0);
-        fin.read(&name[0], length);
+        std::string name;
+        {
+            char buf[1024];
+            if (length > sizeof(buf)) {
+                Die("llama.cpp: lora tensor name is too long");
+            }
+            fin.read(buf, length);
+            name = std::string(buf, length);
+        }
 
         // check for lora suffix and get the type of tensor
         const std::string lora_suffix = ".lora";
@@ -2341,7 +2350,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         base_name.erase(pos);
         // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str());
 
-        if (model_tensors.find(base_name.data()) == model_tensors.end()) {
+        if (model_tensors.find(base_name) == model_tensors.end()) {
             fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
             return 1;
         }
@@ -2421,7 +2430,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
 
             if (scaling != 1.0f) {
                 ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
-                BA = ggml_scale(lora_ctx, BA, scale_tensor);
+                BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
             }
 
             ggml_tensor * r;
@@ -2443,8 +2452,9 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
             lora_tensors.clear();
 
             n_tensors++;
-            if (n_tensors % 4 == 0)
+            if (n_tensors % 4 == 0) {
                 fprintf(stderr, ".");
+            }
         }
     }
 
@@ -2462,7 +2472,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
 
 int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
     // try {
-    return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads);
+        return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads);
     // } catch (const std::string & err) {
     //     fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str());
     //     return 1;
@@ -2473,7 +2483,7 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     return ctx->model.kv_self.n;
 }
 
-#define LLAMA_MAX_RNG_STATE 64*1024
+#define LLAMA_MAX_RNG_STATE (64*1024)
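
Parenthesizing the macro body matters because an unparenthesized 64*1024 can
re-associate with operators at the use site. A contrived standalone
illustration (not part of the patch):

#include <stdio.h>

#define BAD_SIZE  64*1024
#define GOOD_SIZE (64*1024)

int main(void) {
    printf("%d\n", 128 / BAD_SIZE);   // parses as (128/64)*1024 == 2048
    printf("%d\n", 128 / GOOD_SIZE);  // 128/(64*1024) == 0
    return 0;
}
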
 
 void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     if (seed < 0) {
@@ -2482,7 +2492,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }
 
-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2514,8 +2524,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
 }
 
 // Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
-    uint8_t * out = dest;
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    uint8_t * out = dst;
 
     // copy rng
     {
@@ -2575,9 +2585,10 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
-            llama_buffer buffer;
-            buffer.resize(4096);
-            ggml_context * cpy_ctx = ggml_init({ buffer.size, buffer.addr, /* no_alloc */ true });
+
+            char buffer[4096];
+
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
             ggml_cgraph gf{};
             gf.n_threads = 1;
 
@@ -2600,10 +2611,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
             ggml_graph_compute(cpy_ctx, &gf);
+
+            ggml_free(cpy_ctx);
         }
     }
 
-    const size_t written  = out - dest;
+    const size_t written  = out - dst;
     const size_t max_size = llama_get_state_size(ctx);
 
     LLAMA_ASSERT(written <= max_size);