Merge branch 'master' into concedo_experimental
This commit is contained in:
commit
cee018960e
5 changed files with 260 additions and 45 deletions
|
@ -53,7 +53,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
|||
auto end_t = std::chrono::high_resolution_clock::now();
|
||||
if (i == 0) {
|
||||
const float seconds = std::chrono::duration<float>(end_t - start_t).count();
|
||||
printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
|
||||
printf("%.2f seconds per pass - ETA ", seconds);
|
||||
int total_seconds = (int)(seconds * seq_count);
|
||||
if (total_seconds >= 60*60) {
|
||||
printf("%d hours ", total_seconds / (60*60));
|
||||
total_seconds = total_seconds % (60*60);
|
||||
}
|
||||
printf("%d minutes\n", total_seconds / 60);
|
||||
}
|
||||
// We get the logits for all the tokens in the context window (params.n_ctx)
|
||||
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
|
||||
|
|
116
ggml.c
116
ggml.c
|
@ -659,9 +659,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
|
|||
#define QK8_0 32
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
float s; // d * sum(qs[i])
|
||||
int8_t qs[QK8_0]; // quants
|
||||
} block_q8_0;
|
||||
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
||||
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
||||
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
|
@ -1301,13 +1302,39 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
|
|||
|
||||
y[i].d = d;
|
||||
|
||||
int sum = 0;
|
||||
for (int l = 0; l < QK8_0; ++l) {
|
||||
const float v = x[i*QK8_0 + l]*id;
|
||||
y[i].qs[l] = roundf(v);
|
||||
sum += y[i].qs[l];
|
||||
}
|
||||
y[i].s = d * sum;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
// There is no better way of doing this?
|
||||
// I guess not, AVX is not very good at horizontal sums.
|
||||
// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly
|
||||
// faster than the solution below. As I don't have an AVX2 system handt right now to test,
|
||||
// keeping the original.
|
||||
// TODO: Please try and if it does make a differece, uncomment and remove the implementation below.
|
||||
//static inline float horizontal_sum(__m256i a) {
|
||||
// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
|
||||
// __m256i sum = _mm256_add_epi32(a, b);
|
||||
// __m256i hi = _mm256_unpackhi_epi64(sum, sum);
|
||||
// sum = _mm256_add_epi32(sum, hi);
|
||||
// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
|
||||
//}
|
||||
static inline float horizontal_sum(__m256i a) {
|
||||
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
|
||||
__m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
|
||||
__m128i sum64 = _mm_add_epi32(hi64, sum128);
|
||||
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
||||
}
|
||||
#endif
|
||||
|
||||
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
|
||||
assert(k % QK8_0 == 0);
|
||||
const int nb = k / QK8_0;
|
||||
|
@ -1334,6 +1361,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|||
|
||||
y[i].d = d;
|
||||
|
||||
int32x4_t accv = vdupq_n_s32(0);
|
||||
|
||||
for (int l = 0; l < 8; l++) {
|
||||
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
||||
const int32x4_t vi = vcvtnq_s32_f32(v);
|
||||
|
@ -1342,7 +1371,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|||
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
|
||||
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
|
||||
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
|
||||
|
||||
accv = vaddq_s32(accv, vi);
|
||||
}
|
||||
int32_t sum = vaddvq_s32(accv);
|
||||
y[i].s = d * sum;
|
||||
}
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
@ -1390,6 +1423,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|||
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||
|
||||
#if defined(__AVX2__)
|
||||
|
||||
// Compute the sum of the quants and set y[i].s
|
||||
y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
||||
|
||||
// Convert int32 to int16
|
||||
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
||||
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
||||
|
@ -1432,6 +1469,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|||
// scalar
|
||||
quantize_row_q8_0_reference(x, y, k);
|
||||
#endif
|
||||
#if defined __AVX__
|
||||
// TODO: vectorize this
|
||||
for (int i=0; i<nb; ++i) {
|
||||
int sum = 0;
|
||||
for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
|
||||
y[i].s = y[i].d * sum;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
||||
|
@ -2374,14 +2419,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
float sum8 = 0;
|
||||
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_0 * restrict x0 = &x[i + 0];
|
||||
const block_q4_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
|
||||
sum8 += x0->d * y0->s + x1->d * y1->s;
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
const int8x16_t s8b = vdupq_n_s8(0x8);
|
||||
|
||||
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
||||
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
||||
|
@ -2392,12 +2440,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|||
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
||||
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
||||
|
||||
// sub 8
|
||||
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
||||
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
||||
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
||||
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
|
||||
// load y
|
||||
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
||||
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
||||
|
@ -2412,21 +2454,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int32x4_t
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
|
||||
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
|
||||
|
||||
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
|
@ -2438,7 +2480,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|||
#endif
|
||||
}
|
||||
|
||||
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
|
||||
#elif defined(__AVX2__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
@ -2571,12 +2613,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
float summs = 0;
|
||||
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_1 * restrict x0 = &x[i + 0];
|
||||
const block_q4_1 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
|
||||
summs += x0->m * y0->s + x1->m * y1->s;
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
|
||||
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
||||
|
@ -2600,17 +2646,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|||
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
||||
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
||||
|
||||
const int16x8_t s0i = vaddq_s16(
|
||||
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
|
||||
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
|
||||
|
||||
const int16x8_t s1i = vaddq_s16(
|
||||
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
|
||||
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int32x4_t
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
||||
|
@ -2639,24 +2674,26 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|||
#endif
|
||||
}
|
||||
|
||||
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
||||
#elif defined(__AVX2__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
float summs = 0;
|
||||
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float * d0 = &x[i].d;
|
||||
const float * d1 = &y[i].d;
|
||||
const float * m0 = &x[i].m;
|
||||
//const float * m0 = &x[i].m;
|
||||
|
||||
summs += x[i].m * y[i].s;
|
||||
|
||||
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
||||
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
||||
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
||||
|
||||
// Compute combined scales
|
||||
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
||||
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
||||
|
||||
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
||||
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
|
@ -2678,15 +2715,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|||
|
||||
// Accumulate d0*d1*x*y
|
||||
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
||||
|
||||
// Compute sum of y values
|
||||
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
||||
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
||||
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
|
||||
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
|
||||
|
||||
// Accumulate d1*m0*y
|
||||
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
|
||||
}
|
||||
|
||||
// Return horizontal sum of the acc vector
|
||||
|
@ -2695,7 +2723,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
|||
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
||||
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||
|
||||
sumf = _mm_cvtss_f32( res );
|
||||
sumf = _mm_cvtss_f32( res ) + summs;
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
|
|
@ -2099,7 +2099,11 @@ void llama_set_kv_cache(
|
|||
int n_token_count) {
|
||||
// Make sure we have the same kv cache setup
|
||||
LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
|
||||
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
||||
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
||||
memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
|
||||
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
||||
ctx->model.kv_self.v->data = v_data;
|
||||
ctx->model.kv_self.n = n_token_count;
|
||||
}
|
||||
|
||||
|
|
|
@ -2,3 +2,8 @@ set(TARGET vdot)
|
|||
add_executable(${TARGET} vdot.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
|
||||
set(TARGET q8dot)
|
||||
add_executable(${TARGET} q8dot.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
|
|
172
pocs/vdot/q8dot.cpp
Normal file
172
pocs/vdot/q8dot.cpp
Normal file
|
@ -0,0 +1,172 @@
|
|||
#include <cstdio>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ggml.h>
|
||||
|
||||
constexpr int kVecSize = 1 << 16;
|
||||
|
||||
// Copy-pasted from ggml.c
|
||||
#define QK4_0 32
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||
} block_q4_0;
|
||||
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
|
||||
#define QK4_1 32
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
float m; // min
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
} block_q4_1;
|
||||
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
|
||||
|
||||
// Copy-pasted from ggml.c
|
||||
#define QK8_0 32
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
float s; // d * sum(qs[i])
|
||||
int8_t qs[QK8_0]; // quants
|
||||
} block_q8_0;
|
||||
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
||||
|
||||
static_assert(QK4_1 == QK8_0, "QK4_1 and QK8_0 must be the same");
|
||||
static_assert(QK4_0 == QK8_0, "QK4_0 and QK8_0 must be the same");
|
||||
|
||||
template <typename T>
|
||||
void fillQ4blocks(std::vector<T>& blocks, std::mt19937& rndm) {
|
||||
for (auto& b : blocks) {
|
||||
b.d = 1;
|
||||
for (int i=0; i<QK4_1/2; ++i) {
|
||||
uint8_t v1 = rndm() >> 28;
|
||||
uint8_t v2 = rndm() >> 28;
|
||||
b.qs[i] = v1 | (v2 << 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fillQ80blocks(std::vector<block_q8_0>& blocks, std::mt19937& rndm) {
|
||||
for (auto& b : blocks) {
|
||||
b.d = 1;
|
||||
int sum = 0;
|
||||
for (int i=0; i<QK8_0; ++i) {
|
||||
b.qs[i] = (rndm() >> 24) - 128;
|
||||
sum += b.qs[i];
|
||||
}
|
||||
b.s = b.d * sum;
|
||||
}
|
||||
}
|
||||
|
||||
float simpleDot(const block_q4_0& x, const block_q8_0& y) {
|
||||
int s1 = 0; //, s2 = 0;
|
||||
for (int i=0; i<QK4_1/2; i+=2) {
|
||||
int v1 = x.qs[i+0] & 0xf;
|
||||
int v2 = x.qs[i+0] >> 4;
|
||||
int v3 = x.qs[i+1] & 0xf;
|
||||
int v4 = x.qs[i+1] >> 4;
|
||||
int j = 2*i;
|
||||
s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
|
||||
//s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
|
||||
}
|
||||
return y.d * x.d * s1 - 8 * x.d * y.s;
|
||||
//return y.d * x.d * (s1 - 8 * s2);
|
||||
}
|
||||
|
||||
float simpleDot(const block_q4_1& x, const block_q8_0& y) {
|
||||
int s1 = 0; //, s2 = 0;
|
||||
for (int i=0; i<QK4_1/2; i+=2) {
|
||||
int v1 = x.qs[i+0] & 0xf;
|
||||
int v2 = x.qs[i+0] >> 4;
|
||||
int v3 = x.qs[i+1] & 0xf;
|
||||
int v4 = x.qs[i+1] >> 4;
|
||||
int j = 2*i;
|
||||
s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
|
||||
//s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
|
||||
}
|
||||
return y.d * x.d * s1 + y.s * x.m;
|
||||
//return y.d * (x.d * s1 + x.m * s2);
|
||||
}
|
||||
|
||||
struct Stat {
|
||||
double sum = 0, sumt = 0, sumt2 = 0, maxt = 0;
|
||||
int nloop = 0;
|
||||
void addResult(double s, double t) {
|
||||
sum += s;
|
||||
sumt += t; sumt2 += t*t; maxt = std::max(maxt, t);
|
||||
++nloop;
|
||||
}
|
||||
void reportResult(const char* title) const {
|
||||
if (nloop < 1) {
|
||||
printf("%s(%s): no result\n",__func__,title);
|
||||
return;
|
||||
}
|
||||
printf("============ %s\n",title);
|
||||
printf("<dot> = %g\n",sum/nloop);
|
||||
auto t = sumt/nloop, dt = sumt2/nloop - t*t;
|
||||
if (dt > 0) dt = sqrt(dt);
|
||||
printf("<time> = %g +/- %g us. Max. time = %g us.\n",t,dt,maxt);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
int nloop = argc > 1 ? atoi(argv[1]) : 10;
|
||||
int type = argc > 2 ? atoi(argv[2]) : 1;
|
||||
|
||||
std::mt19937 rndm(1234);
|
||||
|
||||
std::vector<block_q4_1> x41;
|
||||
std::vector<block_q4_0> x40;
|
||||
std::vector<block_q8_0> y(kVecSize);
|
||||
if (type == 0) x40.resize(kVecSize);
|
||||
else {
|
||||
x41.resize(kVecSize);
|
||||
for (auto& b : x41) b.m = 1;
|
||||
}
|
||||
|
||||
auto ggml_type = type == 0 ? GGML_TYPE_Q4_0 : GGML_TYPE_Q4_1;
|
||||
|
||||
auto funcs = ggml_internal_get_quantize_fn(ggml_type);
|
||||
|
||||
Stat simple, ggml;
|
||||
|
||||
for (int iloop=0; iloop<nloop; ++iloop) {
|
||||
|
||||
if (type == 0) fillQ4blocks(x40, rndm);
|
||||
else fillQ4blocks(x41, rndm);
|
||||
fillQ80blocks(y, rndm);
|
||||
|
||||
auto t1 = std::chrono::high_resolution_clock::now();
|
||||
double s = 0;
|
||||
if (type == 0) for (int i=0; i<kVecSize; ++i) s += simpleDot(x40[i], y[i]);
|
||||
else for (int i=0; i<kVecSize; ++i) s += simpleDot(x41[i], y[i]);
|
||||
auto t2 = std::chrono::high_resolution_clock::now();
|
||||
auto t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
|
||||
if (iloop > 3) simple.addResult(s, t);
|
||||
|
||||
t1 = std::chrono::high_resolution_clock::now();
|
||||
float fs;
|
||||
if (type == 0) funcs.vec_dot_q(kVecSize * QK4_1, &fs, x40.data(), y.data());
|
||||
else funcs.vec_dot_q(kVecSize * QK4_1, &fs, x41.data(), y.data());
|
||||
t2 = std::chrono::high_resolution_clock::now();
|
||||
t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
|
||||
if (iloop > 3) ggml.addResult(fs, t);
|
||||
|
||||
}
|
||||
|
||||
// Report the time (and the average of the dot products so the compiler does not come up with the idea
|
||||
// of optimizing away the function calls after figuring that the result is not used).
|
||||
simple.reportResult("Simple");
|
||||
ggml.reportResult("ggml");
|
||||
return 0;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue