From 8b9316be70148f63737b13b406420a0631cc1b06 Mon Sep 17 00:00:00 2001 From: qwopqwop200 Date: Thu, 13 Apr 2023 15:16:39 +0900 Subject: [PATCH] fix tab --- ggml.c | 176 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/ggml.c b/ggml.c index 2081e26fd..a6051a03f 100644 --- a/ggml.c +++ b/ggml.c @@ -1154,9 +1154,9 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int i14 = _mm256_packs_epi32( i14, i15 ); // Convert int16 to int8 i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - i4 = _mm256_packs_epi16( i4, i6 ); - i8 = _mm256_packs_epi16( i8, i10 ); - i12 = _mm256_packs_epi16( i12, i14 ); + i4 = _mm256_packs_epi16( i4, i6 ); + i8 = _mm256_packs_epi16( i8, i10 ); + i12 = _mm256_packs_epi16( i12, i14 ); // We got our precious signed bytes, but the order is now wrong // These AVX2 pack instructions process 16-byte pieces independently @@ -1188,8 +1188,8 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int for (int l = 0; l < 32; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l); - for (int l = 0; l < 16; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]); - for (int l = 0; l < 8; l++) minv[4*l] = vminq_f32(srcv[4*l], srcv[4*l + 2]); + for (int l = 0; l < 16; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]); + for (int l = 0; l < 8; l++) minv[4*l] = vminq_f32(srcv[4*l], srcv[4*l + 2]); for (int l = 0; l < 4; l++) minv[8*l] = vminq_f32(srcv[8*l], srcv[8*l + 4]); for (int l = 0; l < 2; l++) minv[16*l] = vminq_f32(minv[16*l], minv[16*l + 8]); for (int l = 0; l < 1; l++) minv[32*l] = vminq_f32(minv[32*l], minv[32*l + 16]); @@ -2731,46 +2731,46 @@ static void ggml_vec_dot_q4_2(const int n, float * restrict s, const void * rest const __m256 scale_1 = _mm256_mul_ps( m0v, d1v ); const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ ); - const uint8_t * restrict x_pp = x[i].qs; - const uint8_t * restrict y_pp = y[i].qs; + const uint8_t * restrict x_pp = x[i].qs; + const uint8_t * restrict y_pp = y[i].qs; // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - for (int l = 0; l < QK128; l += 32) { - __m256i bx = bytesFromNibbles( x_pp + l/2); - __m256i by = bytesFromNibbles( y_pp + l/2); + for (int l = 0; l < QK128; l += 32) { + __m256i bx = bytesFromNibbles( x_pp + l/2); + __m256i by = bytesFromNibbles( y_pp + l/2); - // Now we have a vector with bytes in [ 0 .. 15 ] interval. + // Now we have a vector with bytes in [ 0 .. 15 ] interval. 
- // Sign-extend first 16 signed bytes into int16_t - __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); - __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); - // Compute products of int16_t integers, add pairwise - __m256i i32 = _mm256_madd_epi16( x16, y16 ); + // Sign-extend first 16 signed bytes into int16_t + __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); + __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); + // Compute products of int16_t integers, add pairwise + __m256i i32 = _mm256_madd_epi16( x16, y16 ); - // Sign-extend last 16 signed bytes into int16_t vectors - __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); - __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); - // Accumulate products of int16_t integers - i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) ); + // Sign-extend last 16 signed bytes into int16_t vectors + __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); + __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); + // Accumulate products of int16_t integers + i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) ); - // compute sums of unsigned bytes in bx, by in blocks of 8. - // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000, - // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400. - // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ] - __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() ); - __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() ); - __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); - __m256 sums = _mm256_cvtepi32_ps( sumsi ); + // compute sums of unsigned bytes in bx, by in blocks of 8. + // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000, + // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400. 
+ // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ] + __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() ); + __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() ); + __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); + __m256 sums = _mm256_cvtepi32_ps( sumsi ); - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps( i32 ); - // Apply the scale, and accumulate - // acc += d0*d1*x*y + d0*m1*x + d1*m0*y - acc = _mm256_fmadd_ps( scale_01, p, acc ); - acc = _mm256_fmadd_ps( cross_scales, sums, acc ); - } - // acc_offset += m0*m1 (for each entry in the block) - acc_offset += (*m0)*(*m1); + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps( i32 ); + // Apply the scale, and accumulate + // acc += d0*d1*x*y + d0*m1*x + d1*m0*y + acc = _mm256_fmadd_ps( scale_01, p, acc ); + acc = _mm256_fmadd_ps( cross_scales, sums, acc ); + } + // acc_offset += m0*m1 (for each entry in the block) + acc_offset += (*m0)*(*m1); } // Return horizontal sum of the acc vector @@ -2792,35 +2792,35 @@ static void ggml_vec_dot_q4_2(const int n, float * restrict s, const void * rest const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8_t * restrict x_pp = x0->qs; - const uint8_t * restrict y_pp = x0->qs; + const uint8_t * restrict x_pp = x0->qs; + const uint8_t * restrict y_pp = x0->qs; - for (int l = 0; l < QK128; l += 32) { - const uint8x16_t v0_0 = vld1q_u8(x_pp + l/2); - const uint8x16_t v1_0 = vld1q_u8(y_pp + l/2); + for (int l = 0; l < QK128; l += 32) { + const uint8x16_t v0_0 = vld1q_u8(x_pp + l/2); + const uint8x16_t v1_0 = vld1q_u8(y_pp + l/2); - // and with 0xf - const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); - const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + // and with 0xf + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); - const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); - const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); - // dot product into uint16x8_t - const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); - const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); + // dot product into uint16x8_t + const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); + const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); - const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); - const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); + const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); + const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); - const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); - const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); + const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); + const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); - sum00 += x0->m*y0->m; - sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); - sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); - sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0)); - } + sum00 += x0->m*y0->m; + sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); + sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); + sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0)); + } } sumf = QK128*sum00 + sum01 + sum10 + sum11; @@ -3100,7 +3100,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F16] = 1, [GGML_TYPE_Q4_0] = QK, 
[GGML_TYPE_Q4_1] = QK, - [GGML_TYPE_Q4_2] = QK128, + [GGML_TYPE_Q4_2] = QK128, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, @@ -3111,7 +3111,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F16] = sizeof(ggml_fp16_t), [GGML_TYPE_Q4_0] = sizeof(block_q4_0), [GGML_TYPE_Q4_1] = sizeof(block_q4_1), - [GGML_TYPE_Q4_2] = sizeof(block_q4_2), + [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), @@ -5789,7 +5789,7 @@ static void ggml_compute_forward_dup( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5871,7 +5871,7 @@ static void ggml_compute_forward_add( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5924,7 +5924,7 @@ static void ggml_compute_forward_sub( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5977,7 +5977,7 @@ static void ggml_compute_forward_mul( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6030,7 +6030,7 @@ static void ggml_compute_forward_div( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6079,7 +6079,7 @@ static void ggml_compute_forward_sqr( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6128,7 +6128,7 @@ static void ggml_compute_forward_sqrt( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6187,7 +6187,7 @@ static void ggml_compute_forward_sum( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6265,7 +6265,7 @@ static void ggml_compute_forward_mean( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6330,7 +6330,7 @@ static void ggml_compute_forward_repeat( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6379,7 +6379,7 @@ static void ggml_compute_forward_abs( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6428,7 +6428,7 @@ static void ggml_compute_forward_sgn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6477,7 +6477,7 @@ static void ggml_compute_forward_neg( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6526,7 +6526,7 @@ static void ggml_compute_forward_step( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ 
-6575,7 +6575,7 @@ static void ggml_compute_forward_relu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6641,7 +6641,7 @@ static void ggml_compute_forward_gelu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6709,7 +6709,7 @@ static void ggml_compute_forward_silu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6877,7 +6877,7 @@ static void ggml_compute_forward_rms_norm( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7619,7 +7619,7 @@ static void ggml_compute_forward_scale( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7878,7 +7878,7 @@ static void ggml_compute_forward_diag_mask_inf( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7973,7 +7973,7 @@ static void ggml_compute_forward_soft_max( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8149,7 +8149,7 @@ static void ggml_compute_forward_rope( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8418,7 +8418,7 @@ static void ggml_compute_forward_conv_1d_1s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8687,7 +8687,7 @@ static void ggml_compute_forward_conv_1d_2s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -9173,7 +9173,7 @@ static void ggml_compute_forward_flash_attn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -9385,7 +9385,7 @@ static void ggml_compute_forward_flash_ff( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32:
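Note (not part of the patch): for reference, a minimal scalar sketch of the 4-bit unpacking that the vectorized paths above perform. bytesFromNibbles on AVX2 and the vandq_u8/vshrq_n_u8 pair on NEON both split each packed byte into its low and high nibble, turning 16 bytes into 32 values in the [ 0 .. 15 ] interval. The element order inside each byte is not shown in these hunks, so the ordering below is an assumption; the dot-product sums are unaffected as long as x and y use the same packing.

    #include <stdint.h>

    /* Sketch only: expand n_bytes packed bytes into 2*n_bytes quantized values,
     * mirroring the "unpack 4 bit fields into bytes, making 32 bytes" step above. */
    static void unpack_nibbles_ref(const uint8_t *src, uint8_t *dst, int n_bytes) {
        for (int j = 0; j < n_bytes; j++) {
            dst[2*j + 0] = src[j] & 0x0F;  /* low nibble  */
            dst[2*j + 1] = src[j] >> 4;    /* high nibble */
        }
    }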
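Likewise, a hypothetical scalar reference for the Q4_2 dot product that the AVX2 and NEON blocks compute. The block layout (d, m, qs) and the value of QK128 are not defined in these hunks, so the struct below is an assumption inferred from the diff; the accumulation follows the comments acc += d0*d1*x*y + d0*m1*x + d1*m0*y and acc_offset += m0*m1, and the final sumf = QK128*sum00 + sum01 + sum10 + sum11 in the NEON path.

    #include <stdint.h>

    #define QK128 128  /* assumed block size; not defined in these hunks */

    typedef struct {
        float   d;              /* scale      (assumed layout) */
        float   m;              /* min/offset (assumed layout) */
        uint8_t qs[QK128 / 2];  /* two 4-bit quants per byte   */
    } block_q4_2_ref;

    /* With x_val = d0*qx + m0 and y_val = d1*qy + m1, each block contributes
     *   d0*d1*sum(qx*qy) + d0*m1*sum(qx) + d1*m0*sum(qy) + QK128*m0*m1,
     * which matches sumf = QK128*sum00 + sum01 + sum10 + sum11 above. */
    static float vec_dot_q4_2_ref(int n, const block_q4_2_ref *x, const block_q4_2_ref *y) {
        float sumf = 0.0f;
        for (int i = 0; i < n / QK128; i++) {
            float sum_xy = 0.0f, sum_x = 0.0f, sum_y = 0.0f;
            for (int l = 0; l < QK128 / 2; l++) {
                const int x_lo = x[i].qs[l] & 0x0F, x_hi = x[i].qs[l] >> 4;
                const int y_lo = y[i].qs[l] & 0x0F, y_hi = y[i].qs[l] >> 4;
                sum_x  += x_lo + x_hi;            /* later scaled by d0*m1 */
                sum_y  += y_lo + y_hi;            /* later scaled by d1*m0 */
                sum_xy += x_lo*y_lo + x_hi*y_hi;  /* later scaled by d0*d1 */
            }
            sumf += x[i].d*y[i].d*sum_xy
                  + x[i].d*y[i].m*sum_x
                  + y[i].d*x[i].m*sum_y
                  + QK128*x[i].m*y[i].m;
        }
        return sumf;
    }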