diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1f9fdd30f..2c3c60167 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -307,7 +307,7 @@ add_library(ggml OBJECT
 target_include_directories(ggml PUBLIC .)
 target_compile_features(ggml PUBLIC c_std_11) # don't bump
-target_link_libraries(ggml PRIVATE Threads::Threads ${LLAMA_EXTRA_LIBS})
+target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
 if (BUILD_SHARED_LIBS)
     set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()
diff --git a/Makefile b/Makefile
index f267d0864..3b48eec99 100644
--- a/Makefile
+++ b/Makefile
@@ -101,11 +101,13 @@ ifdef LLAMA_OPENBLAS
 	LDFLAGS += -lopenblas
 endif
 ifdef LLAMA_CUBLAS
-	CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
-	LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
-	OBJS    += ggml-cuda.o
+	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
+	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
+	OBJS      += ggml-cuda.o
+	NVCC      = nvcc
+	NVCCFLAGS = --forward-unknown-to-host-linker -arch=native
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
-	nvcc -arch=native -c -o $@ $<
+	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@
 endif
 ifdef LLAMA_GPROF
 	CFLAGS   += -pg
diff --git a/SHA256SUMS b/SHA256SUMS
index 63fac21ae..1d034b371 100644
--- a/SHA256SUMS
+++ b/SHA256SUMS
@@ -1,12 +1,27 @@
 700df0d3013b703a806d2ae7f1bfb8e59814e3d06ae78be0c66368a50059f33d models/7B/consolidated.00.pth
+666a4bb533b303bdaf89e1b6a3b6f93535d868de31d903afdc20983dc526c847 models/7B/ggml-model-f16.bin
+fcb7664c2e69776920b526362a243e912f73c36b1ec892eb354bab940f5edb5a models/7B/ggml-model-q4_0.bin
+cc061458339a3eb8bcecbf0a825e9924fb7d1a8150f63cd5d091caa99215aafe models/7B/ggml-model-q4_1.bin
+1bc7484c24a87612726d756f1761890e7acf5f412e23378577ce50fbe789b5b8 models/7B/ggml-model-q4_2.bin
+3429bf198ec771886cf81a574df45245f3ebf04f0ce0956b73ef5d0ab01ff48b models/7B/ggml-model-q4_3.bin
 7e89e242ddc0dd6f060b43ca219ce8b3e8f08959a72cb3c0855df8bb04d46265 models/7B/params.json
 745bf4e29a4dd6f411e72976d92b452da1b49168a4f41c951cfcc8051823cf08 models/13B/consolidated.00.pth
 d5ccbcc465c71c0de439a5aeffebe8344c68a519bce70bc7f9f92654ee567085 models/13B/consolidated.01.pth
+2b206e9b21fb1076f11cafc624e2af97c9e48ea09312a0962153acc20d45f808 models/13B/ggml-model-f16.bin
+4b69e4d6b6e3275230955997b90407fceca7e5ab3daf2e63a2c9e7270a8e1e3e models/13B/ggml-model-q4_0.bin
+d9581b5b88e5622532fe897c9f9b0e67a317d22dd27a6f90fa4ab8c6d23ccdbb models/13B/ggml-model-q4_1.bin
+8d55a2077317ec9a928c7851d6a43e08e51f7e9e08360f2a7a7e1deefea3134f models/13B/ggml-model-q4_2.bin
+4208cdec9788ffa48dc1a17af2c36a0299f5bf3eb0e2b87889dda7fad591fca3 models/13B/ggml-model-q4_3.bin
 4ab77bec4d4405ccb66a97b282574c89a94417e3c32e5f68f37e2876fc21322f models/13B/params.json
 e23294a58552d8cdec5b7e8abb87993b97ea6eced4178ff2697c02472539d067 models/30B/consolidated.00.pth
 4e077b7136c7ae2302e954860cf64930458d3076fcde9443f4d0e939e95903ff models/30B/consolidated.01.pth
 24a87f01028cbd3a12de551dcedb712346c0b5cbdeff1454e0ddf2df9b675378 models/30B/consolidated.02.pth
 1adfcef71420886119544949767f6a56cb6339b4d5fcde755d80fe68b49de93b models/30B/consolidated.03.pth
+7e1b524061a9f4b27c22a12d6d2a5bf13b8ebbea73e99f218809351ed9cf7d37 models/30B/ggml-model-f16.bin
+7a679908ce31c9d6ae2e38d6059bcd4d0ad3a870cd58cc1c8f7b36f2b2f51c73 models/30B/ggml-model-q4_0.bin
+7b75ac615fa369ee593493a7e6ef87542bf0350255db928b22c5a24f6d598bcd models/30B/ggml-model-q4_1.bin
+2c82b4954a94a6a284f452f6011c1e4f0d20362c194a0b1eb5737f5fd8a20fb3 models/30B/ggml-model-q4_2.bin
+a6188660199dbcb8d5658abe7d89169869e50423494385830d9e6b330ea7fc33 models/30B/ggml-model-q4_3.bin
 2c07118ea98d69dbe7810d88520e30288fa994751b337f8fca02b171955f44cb models/30B/params.json
 135c563f6b3938114458183afb01adc9a63bef3d8ff7cccc3977e5d3664ecafe models/65B/consolidated.00.pth
 9a600b37b19d38c7e43809485f70d17d1dc12206c07efa83bc72bb498a568bde models/65B/consolidated.01.pth
@@ -16,5 +31,10 @@ e7babf7c5606f165a3756f527cb0fedc4f83e67ef1290391e52fb1cce5f26770 models/65B/con
 a287c0dfe49081626567c7fe87f74cce5831f58e459b427b5e05567641f47b78 models/65B/consolidated.05.pth
 72b4eba67a1a3b18cb67a85b70f8f1640caae9b40033ea943fb166bd80a7b36b models/65B/consolidated.06.pth
 d27f5b0677d7ff129ceacd73fd461c4d06910ad7787cf217b249948c3f3bc638 models/65B/consolidated.07.pth
+60758f2384d74e423dffddfd020ffed9d3bb186ebc54506f9c4a787d0f5367b0 models/65B/ggml-model-f16.bin
+c671fe1bce71499ac732ec999770ebe53ac486623a7891e42c9dfdb6962d2c64 models/65B/ggml-model-q4_0.bin
+4743a28aac3e5f32a6e838a815f51d3779de44fbbe251d745251e66c23c5950f models/65B/ggml-model-q4_1.bin
+4a145a210c56982389b1ed34387e0590c3e0d7325fa9be4f2284fe4d244a3633 models/65B/ggml-model-q4_2.bin
+305e91a4608b4f627b9b8ad5b4af75187d2684254bfd76dcb9db571618ef293c models/65B/ggml-model-q4_3.bin
 999ed1659b469ccc2a941714c0a9656fa571d17c9f7c8c7589817ca90edef51b models/65B/params.json
 9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 models/tokenizer.model
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index b7b3c4196..65db79263 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
             // infinite text generation via context swapping
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
-            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
+            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
             if (n_past + (int) embd.size() > n_ctx) {
                 const int n_left = n_past - params.n_keep;
@@ -282,13 +282,21 @@ int main(int argc, char ** argv) {
                 //printf("\n---\n");
             }
 
-            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return 1;
+            // evaluate tokens in batches
+            // embd is typically prepared beforehand to fit within a batch, but not always
+            for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
+                int n_eval = (int) embd.size() - i;
+                if (n_eval > params.n_batch) {
+                    n_eval = params.n_batch;
+                }
+                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+                    fprintf(stderr, "%s : failed to eval\n", __func__);
+                    return 1;
+                }
+                n_past += n_eval;
             }
         }
 
-        n_past += embd.size();
         embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
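Note: the chunking arithmetic in the main.cpp hunk above is easy to check in isolation. The following standalone C++ sketch (hypothetical n_batch of 8 and a 19-token buffer; the llama_eval call is stubbed out with a printf) shows how the loop partitions pending tokens into chunks of at most params.n_batch:

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> embd(19);   // 19 pending tokens (arbitrary)
    const int n_batch = 8;       // hypothetical batch size
    int n_past = 0;

    for (int i = 0; i < (int) embd.size(); i += n_batch) {
        int n_eval = (int) embd.size() - i;
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
        // llama_eval(ctx, &embd[i], n_eval, n_past, n_threads) would go here
        printf("chunk at %d: %d tokens (n_past=%d)\n", i, n_eval, n_past);
        n_past += n_eval;        // advance past the tokens just evaluated
    }
    return 0;                    // prints chunks of 8, 8, and 3 tokens
}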
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 80792ea0d..615157e7b 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -53,7 +53,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
             const float seconds = std::chrono::duration<float>(end_t - start_t).count();
-            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
+            printf("%.2f seconds per pass - ETA ", seconds);
+            int total_seconds = (int)(seconds * seq_count);
+            if (total_seconds >= 60*60) {
+                printf("%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            printf("%d minutes\n", total_seconds / 60);
         }
         // We get the logits for all the tokens in the context window (params.n_ctx)
         // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
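Note: a quick sanity check of the hour/minute decomposition introduced above (self-contained sketch, not part of the patch):

#include <cstdio>

// Same decomposition as in perplexity.cpp: split a duration into hours and minutes.
void print_eta(int total_seconds) {
    if (total_seconds >= 60*60) {
        printf("%d hours ", total_seconds / (60*60));
        total_seconds %= 60*60;
    }
    printf("%d minutes\n", total_seconds / 60);
}

int main() {
    print_eta(9000);  // 2.5 hours -> "2 hours 30 minutes"
    print_eta(1500);  // 25 minutes -> "25 minutes"
    return 0;
}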
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0baa989a3..fa511c1dc 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,5 +1,7 @@
 #include <stdint.h>
+#include <stdio.h>
 #include <cuda_fp16.h>
+#include <atomic>
 #include "ggml-cuda.h"
 
 typedef uint16_t ggml_fp16_t;
@@ -29,14 +31,12 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
 #define QK4_3 16
 typedef struct {
-    __half d;              // delta
-    __half m;              // min
-    uint8_t qs[QK4_3 / 2]; // nibbles / quants
+    __half  d;              // delta
+    __half  m;              // min
+    uint8_t qs[QK4_3 / 2];  // nibbles / quants
 } block_q4_3;
 static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
-
-
 static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -131,24 +131,98 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
     }
 }
 
-extern "C" {
-    __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_0;
-        dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
-    }
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_0;
+    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+}
 
-    __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_1;
-        dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
-    }
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_1;
+    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+}
 
-    __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_2;
-        dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
-    }
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_2;
+    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+}
 
-    __host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_3;
-        dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_3;
+    dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+    }
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+void ggml_cuda_pool_free(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.ptr == nullptr) {
+            b.ptr = ptr;
+            b.size = size;
+            return;
+        }
+    }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
+cublasHandle_t g_cublasH = NULL;
+cudaStream_t g_cudaStream = NULL;
+
+void ggml_init_cublas(void) {
+    if (g_cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&g_cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
diff --git a/ggml-cuda.h b/ggml-cuda.h
index be140606a..370bbc75f 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -1,7 +1,36 @@
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
 
+#define CUDA_CHECK(err)                                                                 \
+    do {                                                                                \
+        cudaError_t err_ = (err);                                                       \
+        if (err_ != cudaSuccess) {                                                      \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                              \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+extern cublasHandle_t g_cublasH;
+extern cudaStream_t   g_cudaStream;
+
+void   ggml_init_cublas(void);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);
+
 void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
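Note: to make the calling convention of the new pool API concrete, here is a minimal usage sketch following the same malloc/use/sync/free pattern that ggml.c adopts below (the helper name and the memcpy workload are hypothetical; assumes ggml_init_cublas() has already run):

// Grab a device buffer from the pool, do work on the shared stream,
// then return the buffer so later calls can reuse it without cudaMalloc.
static void roundtrip_on_device(const float * host_src, float * host_dst, int n) {
    size_t actual_size;  // the pool may hand back a larger recycled buffer
    float * d_buf = (float *) ggml_cuda_pool_malloc(n * sizeof(float), &actual_size);

    CUDA_CHECK(cudaMemcpyAsync(d_buf, host_src, n * sizeof(float), cudaMemcpyHostToDevice, g_cudaStream));
    // ... kernel launches / cuBLAS calls on g_cudaStream would go here ...
    CUDA_CHECK(cudaMemcpyAsync(host_dst, d_buf, n * sizeof(float), cudaMemcpyDeviceToHost, g_cudaStream));

    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
    ggml_cuda_pool_free(d_buf, actual_size);  // recycle under the *actual* size
}

The point of returning actual_size is that a recycled buffer can be bigger than requested; freeing under the actual size keeps the pool's bookkeeping consistent.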
diff --git a/ggml.c b/ggml.c
index 998602150..2ea4e68fd 100644
--- a/ggml.c
+++ b/ggml.c
@@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
 #include "ggml-cuda.h"
-
-#define CUDA_CHECK(err)                                                        \
-    do {                                                                       \
-        cudaError_t err_ = (err);                                              \
-        if (err_ != cudaSuccess) {                                             \
-            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
-                cudaGetErrorString(err_));                                     \
-            exit(1);                                                           \
-        }                                                                      \
-    } while (0)
-
-#define CUBLAS_CHECK(err)                                                      \
-    do {                                                                       \
-        cublasStatus_t err_ = (err);                                           \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                   \
-            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
-            exit(1);                                                           \
-        }                                                                      \
-    } while (0)
-
-static cublasHandle_t cublasH = NULL;
-static cudaStream_t cudaStream = NULL;
-static void init_cublas(void) {
-    if (cublasH == NULL) {
-        // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&cublasH));
-
-        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
-        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
-        // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
-    }
-}
 #endif
 
 #undef MIN
@@ -657,9 +620,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s;          // d * sum(qs[i])
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
 // reference implementation for deterministic creation of model files
@@ -1299,13 +1263,39 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         y[i].d = d;
 
+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
+            sum += y[i].qs[l];
         }
+        y[i].s = d * sum;
     }
 }
 
+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not; AVX is not very good at horizontal sums.
+// The commented-out solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// I am keeping the original.
+// TODO: Please try and if it does make a difference, uncomment and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+#endif
+
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -1332,6 +1322,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         y[i].d = d;
 
+        int32x4_t accv = vdupq_n_s32(0);
+
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
@@ -1340,7 +1332,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
         }
+        int32_t sum = vaddvq_s32(accv);
+        y[i].s = d * sum;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1384,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );
 
 #if defined(__AVX2__)
+
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
@@ -1430,6 +1430,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
+#if defined __AVX__
+    // TODO: vectorize this
+    for (int i=0; i<nb; ++i) {
+        int sum = 0;
+        for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
+        y[i].s = y[i].d * sum;
+    }
+#endif
 }
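Note: the reason quantize_row_q8_0 now precomputes s = d * sum(qs) is the algebraic identity exploited by the dot products that follow. A scalar sketch of the Q4_0 case (hypothetical 4-element blocks, just to verify the factoring; not part of the patch):

#include <stdio.h>

// With x.qs in [0,15] stored against an implicit offset of 8, and
// y.s = y.d * sum(y.qs), a Q4_0 x Q8_0 block dot product factors as
//   sum((x.qs[l]-8)*x.d * y.qs[l]*y.d)
//     = x.d*y.d*sum(x.qs[l]*y.qs[l]) - 8*x.d*y.s
int main(void) {
    float xd = 0.5f, yd = 0.25f;
    int xq[4] = {3, 12, 7, 9}, yq[4] = {-5, 8, 1, -2};

    float direct = 0, sumqq = 0, sumy = 0;
    for (int l = 0; l < 4; ++l) {
        direct += (xq[l] - 8) * xd * yq[l] * yd;
        sumqq  += xq[l] * yq[l];
        sumy   += yq[l];
    }
    float ys   = yd * sumy;                     // precomputed at quantization time
    float fast = xd * yd * sumqq - 8 * xd * ys; // what the SIMD kernels compute
    printf("%f == %f\n", direct, fast);         // both print 6.750000
    return 0;
}

The Q4_1 case factors the same way, except the offset is the per-block minimum m rather than the constant 8, giving dot = x.d*y.d*sum(q*p) + x.m*y.s; this is exactly what the summs accumulators in the kernels below compute, which is why the "sub 8" and sum-of-y steps can be deleted from the inner loops.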
@@ -2366,14 +2374,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        sum8 += x0->d * y0->s + x1->d * y1->s;
+
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
-        const int8x16_t  s8b   = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2390,12 +2401,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2410,21 +2415,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2441,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2569,12 +2574,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        summs += x0->m * y0->s + x1->m * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2598,17 +2607,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        const int16x8_t s0i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        const int16x8_t s1i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
@@ -2637,24 +2635,26 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+        //const float * m0 = &x[i].m;
+
+        summs += x[i].m * y[i].s;
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2676,15 +2676,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
     // Return horizontal sum of the acc vector
@@ -2693,7 +2684,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
-    sumf = _mm_cvtss_f32( res );
+    sumf = _mm_cvtss_f32( res ) + summs;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
@@ -3720,7 +3711,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         // initialize cuBLAS
         #if defined(GGML_USE_CUBLAS)
-        init_cublas();
+        ggml_init_cublas();
         #endif
 
         is_first_call = false;
@@ -7566,18 +7557,16 @@ static void ggml_compute_forward_mul_mat_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    float *d_X = NULL;
-    float *d_Y = NULL;
-    float *d_D = NULL;
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne10;
     const int y_ne = ne11 * ne10;
    const int d_ne = ne11 * ne01;
 
-    CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+    size_t x_size, y_size, d_size;
+    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #endif
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7589,19 +7578,19 @@ static void ggml_compute_forward_mul_mat_f32(
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
             // compute
             CUBLAS_CHECK(
-                cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                         ne01, ne11, ne10,
                         &alpha, d_X, ne00,
                                 d_Y, ne10,
                         &beta,  d_D, ne01));
 
             // copy data to host
-            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
             // zT = y * xT
             cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7613,10 +7602,10 @@ static void ggml_compute_forward_mul_mat_f32(
         }
     }
 #if defined(GGML_USE_CUBLAS)
-    CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-    CUDA_CHECK(cudaFree(d_X));
-    CUDA_CHECK(cudaFree(d_Y));
-    CUDA_CHECK(cudaFree(d_D));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
 #endif
     //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
@@ -7766,18 +7755,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #if defined(GGML_USE_CUBLAS)
     ggml_fp16_t * const wdata = params->wdata;
 
-    float *d_X = NULL;
-    float *d_Y = NULL;
-    float *d_D = NULL;
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne10;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
-    CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+    size_t x_size, y_size, d_size;
+    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
     float * const wdata = params->wdata;
 #endif
@@ -7811,12 +7798,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
             // compute
             CUBLAS_CHECK(
-                cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                         ne01, ne11, ne10,
                         &alpha, d_X, CUDA_R_16F, ne00,
                                 d_Y, CUDA_R_16F, ne10,
@@ -7825,7 +7812,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         CUBLAS_GEMM_DEFAULT));
 
             // copy data to host
-            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
             const float * x = wdata;
             const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7843,10 +7830,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-    CUDA_CHECK(cudaFree(d_X));
-    CUDA_CHECK(cudaFree(d_Y));
-    CUDA_CHECK(cudaFree(d_D));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
 #endif
     /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
@@ -8014,20 +8001,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    float *d_X = NULL;
-    float *d_Y = NULL;
-    float *d_D = NULL;
-    float *d_Q = NULL;
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne10;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
-    CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
-    CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+    size_t x_size, y_size, d_size, q_size;
+    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+    float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
     void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
     if (type == GGML_TYPE_Q4_0) {
@@ -8057,9 +8041,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 // copy and dequantize on device
                 CUDA_CHECK(
                     cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
 
-                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 #else
                 {
@@ -8075,18 +8059,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, ne00,
                                     d_Y, ne10,
                             &beta,  d_D, ne01));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -8099,11 +8083,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-    CUDA_CHECK(cudaFree(d_X));
-    CUDA_CHECK(cudaFree(d_Y));
-    CUDA_CHECK(cudaFree(d_D));
-    CUDA_CHECK(cudaFree(d_Q));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+    ggml_cuda_pool_free(d_Q, q_size);
 #endif
     //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
diff --git a/llama.cpp b/llama.cpp
index 4a646eb91..0345b61c6 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1618,8 +1618,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // quantize only 2D tensors
         quantize &= (tensor.ne.size() == 2);
 
-        // GG: uncomment this to keep the output layer in FP16
-        //if (tensor.name.rfind("output")) {
+        // uncomment this to keep the output layer in FP16
+        //if (tensor.name == "output.weight") {
         //    quantize = false;
         //}
 
@@ -2092,7 +2092,11 @@ void llama_set_kv_cache(
         int   n_token_count) {
     // Make sure we have the same kv cache setup
     LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
+    void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+    void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
    memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
+    ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+    ctx->model.kv_self.v->data = v_data;
     ctx->model.kv_self.n = n_token_count;
 }
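Note: the pointer save/restore in llama_set_kv_cache above works around an aliasing subtlety: the k and v tensor headers live inside the same buffer that the memcpy overwrites, so a whole-buffer restore would clobber their live data pointers. A stripped-down illustration of the hazard (hypothetical types, not llama.cpp's actual ones):

#include <cstddef>
#include <cstring>

struct tensor { void * data; /* ... header fields stored in the arena ... */ };

// The tensor header sits inside the same arena that holds the cache
// contents, so memcpy over the whole arena overwrites tensor.data too.
void restore_arena(char * arena, std::size_t size, tensor * k, const char * snapshot) {
    void * k_data = k->data;            // remember the live pointer
    std::memcpy(arena, snapshot, size); // clobbers k's header along with the data
    k->data = k_data;                   // put the pointer back
}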
diff --git a/llama_util.h b/llama_util.h
index eba14656a..acb207e65 100755
--- a/llama_util.h
+++ b/llama_util.h
@@ -21,6 +21,9 @@
     #if defined(_POSIX_MAPPED_FILES)
         #include <sys/mman.h>
     #endif
+    #if defined(_POSIX_MEMLOCK_RANGE)
+        #include <sys/resource.h>
+    #endif
 #endif
 #endif
@@ -303,8 +306,18 @@ struct llama_mlock {
         if (!mlock(addr, size)) {
             return true;
         } else {
-            fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n" MLOCK_SUGGESTION,
-                    size, this->size, std::strerror(errno));
+            char* errmsg = std::strerror(errno);
+            bool suggest = (errno == ENOMEM);
+
+            // Check if the resource limit is fine after all
+            struct rlimit lock_limit;
+            if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
+                suggest = false;
+            if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
+                suggest = false;
+
+            fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
+                    size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
             return false;
         }
     }
diff --git a/pocs/vdot/CMakeLists.txt b/pocs/vdot/CMakeLists.txt
index cbc852236..fb89a1cd4 100644
--- a/pocs/vdot/CMakeLists.txt
+++ b/pocs/vdot/CMakeLists.txt
@@ -2,3 +2,8 @@ set(TARGET vdot)
 add_executable(${TARGET} vdot.cpp)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET q8dot)
+add_executable(${TARGET} q8dot.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/pocs/vdot/q8dot.cpp b/pocs/vdot/q8dot.cpp
new file mode 100644
index 000000000..5748c8ac2
--- /dev/null
+++ b/pocs/vdot/q8dot.cpp
@@ -0,0 +1,172 @@
+#include <cstdio>
+#include <cstdint>
+#include <cstdlib>
+#include <cassert>
+#include <cstring>
+#include <cmath>
+#include <vector>
+#include <array>
+#include <random>
+#include <chrono>
+#include <algorithm>
+
+#include <ggml.h>
+
+constexpr int kVecSize = 1 << 16;
+
+// Copy-pasted from ggml.c
+#define QK4_0 32
+typedef struct {
+    float   d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    float   d;              // delta
+    float   m;              // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+// Copy-pasted from ggml.c
+#define QK8_0 32
+typedef struct {
+    float   d;          // delta
+    float   s;          // d * sum(qs[i])
+    int8_t  qs[QK8_0];  // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
+static_assert(QK4_1 == QK8_0, "QK4_1 and QK8_0 must be the same");
+static_assert(QK4_0 == QK8_0, "QK4_0 and QK8_0 must be the same");
+
+template <typename T>
+void fillQ4blocks(std::vector<T>& blocks, std::mt19937& rndm) {
+    for (auto& b : blocks) {
+        b.d = 1;
+        for (int i=0; i<QK4_1/2; ++i) {
+            uint8_t v1 = rndm() >> 28;
+            uint8_t v2 = rndm() >> 28;
+            b.qs[i] = v1 | (v2 << 4);
+        }
+    }
+}
+
+void fillQ80blocks(std::vector<block_q8_0>& blocks, std::mt19937& rndm) {
+    for (auto& b : blocks) {
+        b.d = 1;
+        int sum = 0;
+        for (int i=0; i<QK8_0; ++i) {
+            b.qs[i] = (rndm() >> 24) - 128;
+            sum += b.qs[i];
+        }
+        b.s = b.d * sum;
+    }
+}
+
+float simpleDot(const block_q4_0& x, const block_q8_0& y) {
+    int s1 = 0; //, s2 = 0;
+    for (int i=0; i<QK4_1/2; i+=2) {
+        int v1 = x.qs[i] & 0xf;
+        int v2 = x.qs[i] >> 4;
+        int v3 = x.qs[i+1] & 0xf;
+        int v4 = x.qs[i+1] >> 4;
+        int j = 2*i;
+        s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
+        //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
+    }
+    return y.d * x.d * s1 - 8 * x.d * y.s;
+    //return y.d * x.d * (s1 - 8 * s2);
+}
+
+float simpleDot(const block_q4_1& x, const block_q8_0& y) {
+    int s1 = 0; //, s2 = 0;
+    for (int i=0; i<QK4_1/2; i+=2) {
+        int v1 = x.qs[i] & 0xf;
+        int v2 = x.qs[i] >> 4;
+        int v3 = x.qs[i+1] & 0xf;
+        int v4 = x.qs[i+1] >> 4;
+        int j = 2*i;
+        s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
+        //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
+    }
+    return y.d * x.d * s1 + y.s * x.m;
+    //return y.d * (x.d * s1 + x.m * s2);
+}
+
+struct Stat {
+    double sum = 0, sumt = 0, sumt2 = 0, maxt = 0;
+    int nloop = 0;
+    void addResult(double s, double t) {
+        sum += s;
+        sumt += t; sumt2 += t*t; maxt = std::max(maxt, t);
+        ++nloop;
+    }
+    void reportResult(const char* title) const {
+        if (nloop < 1) {
+            printf("%s(%s): no result\n",__func__,title);
+            return;
+        }
+        printf("============ %s\n",title);
+        printf("<dot> = %g\n",sum/nloop);
+        auto t = sumt/nloop, dt = sumt2/nloop - t*t;
+        if (dt > 0) dt = sqrt(dt);
+        printf("