diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2208f42f7..7e8a29b1e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -81,7 +81,6 @@ jobs: matrix: sanitizer: [ADDRESS, THREAD, UNDEFINED] build_type: [Debug, Release] - accelerate: [ON, OFF] steps: - name: Clone @@ -99,7 +98,7 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DLLAMA_ACCELERATE=${{ matrix.accelerate }} + cmake .. -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} cmake --build . --config ${{ matrix.build_type }} - name: Test diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f9fdd30f..2c3c60167 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,7 +307,7 @@ add_library(ggml OBJECT target_include_directories(ggml PUBLIC .) target_compile_features(ggml PUBLIC c_std_11) # don't bump -target_link_libraries(ggml PRIVATE Threads::Threads ${LLAMA_EXTRA_LIBS}) +target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) if (BUILD_SHARED_LIBS) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() diff --git a/Makefile b/Makefile index 4bf481aa2..3b48eec99 100644 --- a/Makefile +++ b/Makefile @@ -101,11 +101,13 @@ ifdef LLAMA_OPENBLAS LDFLAGS += -lopenblas endif ifdef LLAMA_CUBLAS - CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include - LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64 - OBJS += ggml-cuda.o + CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include + LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 + OBJS += ggml-cuda.o + NVCC = nvcc + NVCCFLAGS = --forward-unknown-to-host-linker -arch=native ggml-cuda.o: ggml-cuda.cu ggml-cuda.h - nvcc -arch=native -c -o $@ $< + $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@ endif ifdef LLAMA_GPROF CFLAGS += -pg diff --git a/SHA256SUMS b/SHA256SUMS index 63fac21ae..1d034b371 100644 --- a/SHA256SUMS +++ b/SHA256SUMS @@ -1,12 +1,27 @@ 700df0d3013b703a806d2ae7f1bfb8e59814e3d06ae78be0c66368a50059f33d models/7B/consolidated.00.pth +666a4bb533b303bdaf89e1b6a3b6f93535d868de31d903afdc20983dc526c847 models/7B/ggml-model-f16.bin +fcb7664c2e69776920b526362a243e912f73c36b1ec892eb354bab940f5edb5a models/7B/ggml-model-q4_0.bin +cc061458339a3eb8bcecbf0a825e9924fb7d1a8150f63cd5d091caa99215aafe models/7B/ggml-model-q4_1.bin +1bc7484c24a87612726d756f1761890e7acf5f412e23378577ce50fbe789b5b8 models/7B/ggml-model-q4_2.bin +3429bf198ec771886cf81a574df45245f3ebf04f0ce0956b73ef5d0ab01ff48b models/7B/ggml-model-q4_3.bin 7e89e242ddc0dd6f060b43ca219ce8b3e8f08959a72cb3c0855df8bb04d46265 models/7B/params.json 745bf4e29a4dd6f411e72976d92b452da1b49168a4f41c951cfcc8051823cf08 models/13B/consolidated.00.pth d5ccbcc465c71c0de439a5aeffebe8344c68a519bce70bc7f9f92654ee567085 models/13B/consolidated.01.pth +2b206e9b21fb1076f11cafc624e2af97c9e48ea09312a0962153acc20d45f808 models/13B/ggml-model-f16.bin +4b69e4d6b6e3275230955997b90407fceca7e5ab3daf2e63a2c9e7270a8e1e3e models/13B/ggml-model-q4_0.bin +d9581b5b88e5622532fe897c9f9b0e67a317d22dd27a6f90fa4ab8c6d23ccdbb models/13B/ggml-model-q4_1.bin +8d55a2077317ec9a928c7851d6a43e08e51f7e9e08360f2a7a7e1deefea3134f models/13B/ggml-model-q4_2.bin +4208cdec9788ffa48dc1a17af2c36a0299f5bf3eb0e2b87889dda7fad591fca3 models/13B/ggml-model-q4_3.bin 4ab77bec4d4405ccb66a97b282574c89a94417e3c32e5f68f37e2876fc21322f models/13B/params.json 
e23294a58552d8cdec5b7e8abb87993b97ea6eced4178ff2697c02472539d067 models/30B/consolidated.00.pth 4e077b7136c7ae2302e954860cf64930458d3076fcde9443f4d0e939e95903ff models/30B/consolidated.01.pth 24a87f01028cbd3a12de551dcedb712346c0b5cbdeff1454e0ddf2df9b675378 models/30B/consolidated.02.pth 1adfcef71420886119544949767f6a56cb6339b4d5fcde755d80fe68b49de93b models/30B/consolidated.03.pth +7e1b524061a9f4b27c22a12d6d2a5bf13b8ebbea73e99f218809351ed9cf7d37 models/30B/ggml-model-f16.bin +7a679908ce31c9d6ae2e38d6059bcd4d0ad3a870cd58cc1c8f7b36f2b2f51c73 models/30B/ggml-model-q4_0.bin +7b75ac615fa369ee593493a7e6ef87542bf0350255db928b22c5a24f6d598bcd models/30B/ggml-model-q4_1.bin +2c82b4954a94a6a284f452f6011c1e4f0d20362c194a0b1eb5737f5fd8a20fb3 models/30B/ggml-model-q4_2.bin +a6188660199dbcb8d5658abe7d89169869e50423494385830d9e6b330ea7fc33 models/30B/ggml-model-q4_3.bin 2c07118ea98d69dbe7810d88520e30288fa994751b337f8fca02b171955f44cb models/30B/params.json 135c563f6b3938114458183afb01adc9a63bef3d8ff7cccc3977e5d3664ecafe models/65B/consolidated.00.pth 9a600b37b19d38c7e43809485f70d17d1dc12206c07efa83bc72bb498a568bde models/65B/consolidated.01.pth @@ -16,5 +31,10 @@ e7babf7c5606f165a3756f527cb0fedc4f83e67ef1290391e52fb1cce5f26770 models/65B/con a287c0dfe49081626567c7fe87f74cce5831f58e459b427b5e05567641f47b78 models/65B/consolidated.05.pth 72b4eba67a1a3b18cb67a85b70f8f1640caae9b40033ea943fb166bd80a7b36b models/65B/consolidated.06.pth d27f5b0677d7ff129ceacd73fd461c4d06910ad7787cf217b249948c3f3bc638 models/65B/consolidated.07.pth +60758f2384d74e423dffddfd020ffed9d3bb186ebc54506f9c4a787d0f5367b0 models/65B/ggml-model-f16.bin +c671fe1bce71499ac732ec999770ebe53ac486623a7891e42c9dfdb6962d2c64 models/65B/ggml-model-q4_0.bin +4743a28aac3e5f32a6e838a815f51d3779de44fbbe251d745251e66c23c5950f models/65B/ggml-model-q4_1.bin +4a145a210c56982389b1ed34387e0590c3e0d7325fa9be4f2284fe4d244a3633 models/65B/ggml-model-q4_2.bin +305e91a4608b4f627b9b8ad5b4af75187d2684254bfd76dcb9db571618ef293c models/65B/ggml-model-q4_3.bin 999ed1659b469ccc2a941714c0a9656fa571d17c9f7c8c7589817ca90edef51b models/65B/params.json 9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 models/tokenizer.model diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 48d8657c5..28b724933 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -264,7 +264,7 @@ int main(int argc, char ** argv) { // infinite text generation via context swapping // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past) - // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches if (n_past + (int) embd.size() > n_ctx) { const int n_left = n_past - params.n_keep; @@ -282,13 +282,21 @@ int main(int argc, char ** argv) { //printf("\n---\n"); } - if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.n_ethreads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; + // evaluate tokens in batches + // embd is typically prepared beforehand to fit within a batch, but not always + for (int i = 0; i < (int) embd.size(); i += params.n_batch) { + int n_eval = (int) embd.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads, params.n_ethreads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + n_past += n_eval; } } - n_past += embd.size(); 
embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 165fb80d5..4b6118e7f 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -54,7 +54,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) { auto end_t = std::chrono::high_resolution_clock::now(); if (i == 0) { const float seconds = std::chrono::duration(end_t - start_t).count(); - printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0)); + printf("%.2f seconds per pass - ETA ", seconds); + int total_seconds = (int)(seconds * seq_count); + if (total_seconds >= 60*60) { + printf("%d hours ", total_seconds / (60*60)); + total_seconds = total_seconds % (60*60); + } + printf("%d minutes\n", total_seconds / 60); } // We get the logits for all the tokens in the context window (params.n_ctx) // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index cd973e8ac..4e6c2c831 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -15,6 +15,8 @@ #include #include #include +#include +#include struct quantize_stats_params { std::string model = "models/7B/ggml-model-f16.bin"; @@ -27,7 +29,6 @@ struct quantize_stats_params { std::vector include_types; }; -const int64_t SCRATCH_ELEMENTS = 32*32; const size_t HISTOGRAM_BUCKETS = 150; const double HISTOGRAM_RANGE = 0.03; @@ -90,6 +91,13 @@ void update_error_stats(int64_t nelements, const float * input, const float * ou stats.num_samples += nelements; } +void combine_error_stats(error_stats & into, const error_stats & from) { + into.num_samples += from.num_samples; + into.total_error += from.total_error; + if (from.max_error > into.max_error) into.max_error = from.max_error; + for (size_t i=0; inb[3] == tensor->nb[2]*tensor->ne[2]; } +void test_roundtrip_on_chunk( + const ggml_tensor * layer, + int64_t offset, + int64_t chunk_size, + const quantize_fns_t & qfns, + bool use_reference, + float * input_scratch, + char * quantized_scratch, + float * output_scratch, + error_stats & stats) { + + if (layer->type == GGML_TYPE_F16) { + for (int i = 0; i < chunk_size; i++) { + input_scratch[i] = ggml_get_f32_1d(layer, i + offset); + } + } else { + input_scratch = ggml_get_data_f32(layer) + offset; + } + + if (use_reference) { + qfns.quantize_row_q_reference(input_scratch, quantized_scratch, chunk_size); + } else { + qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size); + } + qfns.dequantize_row_q(quantized_scratch, output_scratch, chunk_size); + + update_error_stats(chunk_size, input_scratch, output_scratch, stats); +} + + // Run quantization function for a single layer and update error stats void test_roundtrip_on_layer( std::string & name, @@ -137,40 +175,61 @@ void test_roundtrip_on_layer( const quantize_fns_t & qfns, bool use_reference, const ggml_tensor * layer, - float * input_scratch, - char *quantized_scratch, - float * output_scratch, - error_stats & total_error) { + std::vector & input_scratch, + std::vector & quantized_scratch, + std::vector & output_scratch, + error_stats & total_error, + int max_thread = 0) { assert(tensor_is_contiguous(layer)); error_stats layer_error {}; - int64_t nelements = ggml_nelements(layer); + uint64_t nelements = ggml_nelements(layer); - for (int64_t offset = 0; offset < 
nelements; offset += SCRATCH_ELEMENTS) { - int64_t chunk_size = std::min(SCRATCH_ELEMENTS, nelements - offset); - - if (layer->type == GGML_TYPE_F16) { - for (int i = 0; i < chunk_size; i++) { - input_scratch[i] = ggml_get_f32_1d(layer, i + offset); - } - } else { - input_scratch = ggml_get_data_f32(layer) + offset; - } - - if (use_reference) { - qfns.quantize_row_q_reference(input_scratch, quantized_scratch, chunk_size); - } else { - qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size); - } - qfns.dequantize_row_q(quantized_scratch, output_scratch, chunk_size); - - update_error_stats(chunk_size, input_scratch, output_scratch, total_error); - if (print_layer_stats) { - update_error_stats(chunk_size, input_scratch, output_scratch, layer_error); - } + float* input_scratch_ptr = nullptr; + if (layer->type == GGML_TYPE_F16) { + if (input_scratch.size() < nelements) input_scratch.resize(nelements); + input_scratch_ptr = input_scratch.data(); } + if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements); + if (output_scratch.size() < nelements) output_scratch.resize(nelements); + + if (max_thread < 1) max_thread = std::thread::hardware_concurrency(); + int chunk_size = 32*512; + int num_chunks = (nelements + chunk_size - 1)/chunk_size; + + if (num_chunks < 2 || max_thread < 2) { + test_roundtrip_on_chunk(layer, 0, nelements, qfns, use_reference, input_scratch_ptr, quantized_scratch.data(), + output_scratch.data(), print_layer_stats ? layer_error : total_error); + } else { + auto & stats = print_layer_stats ? layer_error : total_error; + std::mutex mutex; + uint64_t counter = 0; + auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr, + &quantized_scratch, &output_scratch, chunk_size] () { + error_stats local_stats {}; + while (true) { + std::unique_lock lock(mutex); + uint64_t offset = counter; counter += chunk_size; + if (offset >= nelements) { + combine_error_stats(stats, local_stats); + break; + } + lock.unlock(); + uint64_t chunk = offset + chunk_size < nelements ? 
chunk_size : nelements - offset; + test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset, + quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats); + } + }; + int nthread = std::min(num_chunks, max_thread); + std::vector workers(nthread-1); + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + } + if (print_layer_stats) { print_error_stats(name, layer_error, false); + combine_error_stats(total_error, layer_error); } } @@ -181,6 +240,7 @@ int main(int argc, char ** argv) { // read command line + int max_thread = 0; bool invalid_param = false; std::string arg; for (int i = 1; i < argc; i++) { @@ -230,6 +290,12 @@ int main(int argc, char ** argv) { fprintf(stderr, "error: %s not in list of types\n", argv[i]); invalid_param = true; } + } else if (arg == "-n" || arg == "--num-threads") { + if (++i >= argc) { + invalid_param = true; + break; + } + max_thread = atoi(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); quantize_stats_print_usage(argc, argv); @@ -295,9 +361,9 @@ int main(int argc, char ** argv) { } printf("testing %d layers with max size %" PRId64 "\n", included_layers, max_nelements); // allocate scratch space - std::vector input_scratch(SCRATCH_ELEMENTS); - std::vector quantized_scratch(SCRATCH_ELEMENTS*4); - std::vector output_scratch(SCRATCH_ELEMENTS); + std::vector input_scratch; + std::vector quantized_scratch; + std::vector output_scratch; // loop throught quantization types for (int i = 0; i < GGML_TYPE_COUNT; i++) { @@ -328,10 +394,11 @@ int main(int argc, char ** argv) { qfns, params.reference, kv_tensor.second, - input_scratch.data(), - quantized_scratch.data(), - output_scratch.data(), - global_stats + input_scratch, + quantized_scratch, + output_scratch, + global_stats, + max_thread ); } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 59cb67440..5b4812c62 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -10,11 +10,12 @@ int main(int argc, char ** argv) { ggml_time_init(); - if (argc != 4) { - fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + if (argc < 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type [nthread]\n", argv[0]); fprintf(stderr, " type = %d - q4_0\n", LLAMA_FTYPE_MOSTLY_Q4_0); fprintf(stderr, " type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1); fprintf(stderr, " type = %d - q4_2\n", LLAMA_FTYPE_MOSTLY_Q4_2); + fprintf(stderr, " type = %d - q4_3\n", LLAMA_FTYPE_MOSTLY_Q4_3); return 1; } @@ -29,6 +30,7 @@ int main(int argc, char ** argv) { const std::string fname_out = argv[2]; const enum llama_ftype ftype = (enum llama_ftype)atoi(argv[3]); + int nthread = argc > 4 ? 
atoi(argv[4]) : 0; const int64_t t_main_start_us = ggml_time_us(); @@ -38,7 +40,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype)) { + if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7cd116602..fa511c1dc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,5 +1,7 @@ #include +#include #include +#include #include "ggml-cuda.h" typedef uint16_t ggml_fp16_t; @@ -22,11 +24,18 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 b #define QK4_2 16 typedef struct { - __half d; // delta + __half d; // delta uint8_t qs[QK4_2 / 2]; // nibbles / quants } block_q4_2; static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding"); +#define QK4_3 16 +typedef struct { + __half d; // delta + __half m; // min + uint8_t qs[QK4_3 / 2]; // nibbles / quants +} block_q4_3; +static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); static __global__ void dequantize_block_q4_0(const void * vx, float * y) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -98,19 +107,122 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) { } } -extern "C" { - __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_0; - dequantize_block_q4_0<<>>(vx, y); - } +static __global__ void dequantize_block_q4_3(const void * vx, float * y) { + const block_q4_3 * x = (const block_q4_3 *) vx; - __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_1; - dequantize_block_q4_1<<>>(vx, y); - } + const int i = blockIdx.x; - __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_2; - dequantize_block_q4_2<<>>(vx, y); + const float d = x[i].d; + const float m = x[i].m; + + const uint8_t * pp = x[i].qs; + + for (int l = 0; l < QK4_3; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK4_3 + l + 0] = v0; + y[i*QK4_3 + l + 1] = v1; + } +} + +void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK4_0; + dequantize_block_q4_0<<>>(vx, y); +} + +void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK4_1; + dequantize_block_q4_1<<>>(vx, y); +} + +void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK4_2; + dequantize_block_q4_2<<>>(vx, y); +} + +void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK4_3; + dequantize_block_q4_3<<>>(vx, y); +} + +// buffer pool for cuda +#define MAX_CUDA_BUFFERS 16 + +struct scoped_spin_lock { + std::atomic_flag& lock; + scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { + while (lock.test_and_set(std::memory_order_acquire)) { + ; // spin + } + } + ~scoped_spin_lock() { + lock.clear(std::memory_order_release); + } + scoped_spin_lock(const scoped_spin_lock&) = delete; + scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; +}; + +struct cuda_buffer { 
+ void * ptr = nullptr; + size_t size = 0; +}; + +static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; +static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; + +void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { + scoped_spin_lock lock(g_cuda_pool_lock); + + for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { + cuda_buffer& b = g_cuda_buffer_pool[i]; + if (b.size >= size && b.ptr != nullptr) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + void * ptr; + CUDA_CHECK(cudaMalloc((void **) &ptr, size)); + *actual_size = size; + return ptr; +} + +void ggml_cuda_pool_free(void * ptr, size_t size) { + scoped_spin_lock lock(g_cuda_pool_lock); + + for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { + cuda_buffer& b = g_cuda_buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); + CUDA_CHECK(cudaFree(ptr)); +} + +cublasHandle_t g_cublasH = NULL; +cudaStream_t g_cudaStream = NULL; + +void ggml_init_cublas(void) { + if (g_cublasH == NULL) { + // create cublas handle, bind a stream + CUBLAS_CHECK(cublasCreate(&g_cublasH)); + + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking)); + + CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream)); + + // configure logging to stdout + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); } } diff --git a/ggml-cuda.h b/ggml-cuda.h index 646caafc6..370bbc75f 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -1,10 +1,40 @@ +#include +#include + #ifdef __cplusplus extern "C" { #endif +#define CUDA_CHECK(err) \ + do { \ + cudaError_t err_ = (err); \ + if (err_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ + cudaGetErrorString(err_)); \ + exit(1); \ + } \ + } while (0) + +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +extern cublasHandle_t g_cublasH; +extern cudaStream_t g_cudaStream; + +void ggml_init_cublas(void); +void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); +void ggml_cuda_pool_free(void * ptr, size_t size); + void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); #ifdef __cplusplus } diff --git a/ggml.c b/ggml.c index 9a3430859..2ea4e68fd 100644 --- a/ggml.c +++ b/ggml.c @@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) { #elif defined(GGML_USE_OPENBLAS) #include #elif defined(GGML_USE_CUBLAS) -#include -#include #include "ggml-cuda.h" - -#define CUDA_CHECK(err) \ - do { \ - cudaError_t err_ = (err); \ - if (err_ != cudaSuccess) { \ - printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ - cudaGetErrorString(err_)); \ - exit(1); \ - } \ - } while (0) - -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t err_ = (err); \ - if (err_ != CUBLAS_STATUS_SUCCESS) { \ - printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ - exit(1); \ - } \ - } while (0) - -static cublasHandle_t cublasH = NULL; -static cudaStream_t cudaStream = NULL; -static void init_cublas(void) { - if (cublasH == 
NULL) { - // create cublas handle, bind a stream - CUBLAS_CHECK(cublasCreate(&cublasH)); - - CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking)); - - CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream)); - - // configure logging to stdout - // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); - } -} #endif #undef MIN @@ -467,12 +430,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // quantization // -// AVX routines provided by GH user Const-me -// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 +#if __AVX__ || __AVX2__ || __AVX512F__ +// Unpack 16 4-bit fields into 16 bytes +// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval +static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi) +{ + // Load 8 bytes from memory + __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi ); + + // Expand bytes into uint16_t values + __m128i bytes = _mm_cvtepu8_epi16( tmp ); + + // Unpack values into individual bytes + const __m128i lowMask = _mm_set1_epi8( 0xF ); + __m128i high = _mm_andnot_si128( lowMask, bytes ); + __m128i low = _mm_and_si128( lowMask, bytes ); + high = _mm_slli_epi16( high, 4 ); + bytes = _mm_or_si128( low, high ); + return bytes; +} + #if __AVX2__ || __AVX512F__ // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytesFromNibbles( const uint8_t* rsi ) +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { // Load 16 bytes from memory __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi ); @@ -503,24 +484,7 @@ static inline __m128i packNibbles( __m256i bytes ) __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); return _mm_packus_epi16( r0, r1 ); } -#elif __AVX__ -static inline __m128i bytesFromNibbles( const uint8_t* rsi ) -{ - // Load 8 bytes from memory - __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi ); - - // Expand bytes into uint16_t values - __m128i bytes = _mm_cvtepu8_epi16( tmp ); - - // Unpack values into individual bytes - const __m128i lowMask = _mm_set1_epi8( 0xF ); - __m128i high = _mm_andnot_si128( lowMask, bytes ); - __m128i low = _mm_and_si128( lowMask, bytes ); - high = _mm_slli_epi16( high, 4 ); - bytes = _mm_or_si128( low, high ); - return bytes; -} - +#else static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) { // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh @@ -537,6 +501,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) return _mm_packus_epi16( bytes1, bytes2); } #endif +#endif // __AVX__ || __AVX2__ || __AVX512F__ #if __ARM_NEON @@ -635,7 +600,7 @@ typedef struct { float m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK4_2 16 typedef struct { @@ -644,12 +609,21 @@ typedef struct { } block_q4_2; static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding"); +#define QK4_3 16 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qs[QK4_3 / 2]; // nibbles / quants +} block_q4_3; +static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); + #define QK8_0 32 typedef struct { float d; // delta + float s; // d * sum(qs[i]) int8_t qs[QK8_0]; // quants } 
block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 // reference implementation for deterministic creation of model files
@@ -1201,7 +1175,6 @@ static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restri
     const int nb = k / QK4_2;
     for (int i = 0; i < nb; i++) {
-
         float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
         y[i].d = GGML_FP32_TO_FP16(scale);
@@ -1229,6 +1202,49 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
     quantize_row_q4_2_rmse(x, y, k);
 }
+static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK4_3; l++) {
+            const float v = x[i*QK4_3 + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
+
+            const uint8_t vi0 = (int) (v0 + 0.5f);
+            const uint8_t vi1 = (int) (v1 + 0.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_3 == 0);
+
+    block_q4_3 * restrict y = vy;
+
+    quantize_row_q4_3_reference(x, y, k);
+}
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
     assert(k % QK8_0 == 0);
@@ -1247,13 +1263,39 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         y[i].d = d;
+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
+            sum += y[i].qs[l];
         }
+        y[i].s = d * sum;
     }
 }
+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// I'm keeping the original.
+// TODO: Please try it and, if it does make a difference, uncomment and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) { +// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a))); +// __m256i sum = _mm256_add_epi32(a, b); +// __m256i hi = _mm256_unpackhi_epi64(sum, sum); +// sum = _mm256_add_epi32(sum, hi); +// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); +//} +static inline float horizontal_sum(__m256i a) { + __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); + __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + __m128i sum64 = _mm_add_epi32(hi64, sum128); + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} +#endif + static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1280,6 +1322,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].d = d; + int32x4_t accv = vdupq_n_s32(0); + for (int l = 0; l < 8; l++) { const float32x4_t v = vmulq_n_f32(srcv[l], id); const int32x4_t vi = vcvtnq_s32_f32(v); @@ -1288,7 +1332,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1); y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); } + int32_t sum = vaddvq_s32(accv); + y[i].s = d * sum; } #elif defined(__AVX2__) || defined(__AVX__) for (int i = 0; i < nb; i++) { @@ -1336,6 +1384,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int __m256i i3 = _mm256_cvtps_epi32( v3 ); #if defined(__AVX2__) + + // Compute the sum of the quants and set y[i].s + y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 @@ -1378,6 +1430,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // scalar quantize_row_q8_0_reference(x, y, k); #endif +#if defined __AVX__ + // TODO: vectorize this + for (int i=0; i> 4; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK4_3 + l + 0] = v0; + y[i*QK4_3 + l + 1] = v1; + + assert(!isnan(y[i*QK4_3 + l + 0])); + assert(!isnan(y[i*QK4_3 + l + 1])); + } + } +} + static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = { @@ -1659,6 +1750,13 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_dot = quantize_row_q8_0, .vec_dot_q = ggml_vec_dot_q4_2_q8_0, }, + [GGML_TYPE_Q4_3] = { + .dequantize_row_q = dequantize_row_q4_3, + .quantize_row_q = quantize_row_q4_3, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q4_3_q8_0, + }, [GGML_TYPE_Q8_0] = { .dequantize_row_q 
= NULL, // TODO .quantize_row_q = quantize_row_q8_0, @@ -2282,14 +2380,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); + float sum8 = 0; + for (int i = 0; i < nb; i += 2) { const block_q4_0 * restrict x0 = &x[i + 0]; const block_q4_0 * restrict x1 = &x[i + 1]; const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; + sum8 += x0->d * y0->s + x1->d * y1->s; + const uint8x16_t m4b = vdupq_n_u8(0xf); - const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2300,12 +2401,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); @@ -2320,21 +2415,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs); + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); @@ -2346,7 +2441,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; #elif defined(__AVX2__) // Initialize 
accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2356,7 +2451,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * /* Compute combined scale for the block */ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); - __m256i bx = bytesFromNibbles(x[i].qs); + __m256i bx = bytes_from_nibbles_32(x[i].qs); // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. const __m256i off = _mm256_set1_epi8( 8 ); @@ -2402,7 +2497,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * __m128i i32[2]; for (int j = 0; j < 2; ++j) { // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes - __m128i bx = bytesFromNibbles( x[i].qs + 8*j ); + __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j); __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j)); // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. @@ -2479,12 +2574,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); + float summs = 0; + for (int i = 0; i < nb; i += 2) { const block_q4_1 * restrict x0 = &x[i + 0]; const block_q4_1 * restrict x1 = &x[i + 1]; const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; + summs += x0->m * y0->s + x1->m * y1->s; + const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t v0_0 = vld1q_u8(x0->qs); @@ -2508,17 +2607,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); - const int16x8_t s0i = vaddq_s16( - vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))), - vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs)))); - - const int16x8_t s1i = vaddq_s16( - vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))), - vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs)))); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d); - #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); @@ -2547,27 +2635,29 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); + float summs = 0; + // Main loop for (int i = 0; i < nb; ++i) { const float * d0 = &x[i].d; const float * d1 = &y[i].d; - const float * m0 = &x[i].m; + //const float * m0 = &x[i].m; + + summs += x[i].m * y[i].s; const __m256 d0v = _mm256_broadcast_ss( d0 ); const __m256 d1v = _mm256_broadcast_ss( d1 ); - const __m256 m0v = _mm256_broadcast_ss( m0 ); // Compute combined scales const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - const __m256 d1m0 = _mm256_mul_ps( d1v, m0v ); // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i bx = bytesFromNibbles( x[i].qs ); + const __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); // Get 
absolute values of x vectors @@ -2586,15 +2676,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * // Accumulate d0*d1*x*y acc = _mm256_fmadd_ps( d0d1, xy, acc ); - - // Compute sum of y values - const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); - const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); - const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones ); - const __m256 ysum = _mm256_cvtepi32_ps( ysumi ); - - // Accumulate d1*m0*y - acc = _mm256_fmadd_ps( d1m0, ysum, acc ); } // Return horizontal sum of the acc vector @@ -2603,7 +2684,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - sumf = _mm_cvtss_f32( res ); + sumf = _mm_cvtss_f32( res ) + summs; #else // scalar for (int i = 0; i < nb; i++) { @@ -2653,6 +2734,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1]; const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0]; const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; @@ -2721,6 +2803,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d)); + + __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + __m256i bx = _mm256_set_m128i(bx1, bx0); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. 
+ const __m256i off = _mm256_set1_epi8(8); + bx = _mm256_sub_epi8(bx, off); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(bx, bx); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(by, bx); + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + + const __m256i ones = _mm256_set1_epi16(1); + __m256i xy_q = _mm256_madd_epi16(ones, dot); + + /* Convert to vectore of 8 int32_t to 8 floats */ + __m256 q = _mm256_cvtepi32_ps(xy_q); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps(acc, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(acc)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + + sumf = _mm_cvtss_f32(res); #else // scalar for (int i = 0; i < nb; i++) { @@ -2762,6 +2889,154 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * *s = sumf; } +static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + assert(QK8_0 == 2*QK4_2); + + const block_q4_3 * restrict x = vx; + const block_q8_0 * restrict y = vy; + + float sumf = 0.0; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i += 2) { + const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; + const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; + const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0]; + const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1]; + + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); + const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); + const float x1_0d = GGML_FP16_TO_FP32(x1_0->d); + const float x1_1d = GGML_FP16_TO_FP32(x1_1->d); + + const float x0_0m = GGML_FP16_TO_FP32(x0_0->m); + const float x0_1m = GGML_FP16_TO_FP32(x0_1->m); + const float x1_0m = GGML_FP16_TO_FP32(x1_0->m); + const float x1_1m = GGML_FP16_TO_FP32(x1_1->m); + + const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); + const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs)); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // interleave + const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); + const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); + const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h); + const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l))); + const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h))); + + const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l))); + 
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h))); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d); +#endif + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#else + // scalar + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[2*i + 0].qs; + const uint8_t * restrict x1 = x[2*i + 1].qs; + const int8_t * restrict y0 = y[i].qs; + + const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); + const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); + const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); + const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); + + int sy_0 = 0; + int sy_1 = 0; + + int sxy_0 = 0; + int sxy_1 = 0; + + for (int j = 0; j < QK8_0/4; j++) { + const uint8_t v0 = x0[j]; + const uint8_t v1 = x1[j]; + + const int x0_0 = v0 & 0xf; + const int x1_0 = v0 >> 4; + + const int x0_1 = v1 & 0xf; + const int x1_1 = v1 >> 4; + + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; + + const int y0_1 = y0[2*(j + QK8_0/4) + 0]; + const int y1_1 = y0[2*(j + QK8_0/4) + 1]; + + sy_0 += y0_0 + y1_0; + sy_1 += y0_1 + y1_1; + + sxy_0 += x0_0*y0_0 + x1_0*y1_0; + sxy_1 += x0_1*y0_1 + x1_1*y1_1; + } + + sumf += (d0*sxy_0 + m0*sy_0)*y[i].d; + sumf += (d1*sxy_1 + m1*sy_1)*y[i].d; + } +#endif + + *s = sumf; +} + + // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { @@ -3009,12 +3284,13 @@ static 
const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = QK4_0, [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, + [GGML_TYPE_Q4_3] = QK4_3, [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 9, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3022,12 +3298,13 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = sizeof(block_q4_0), [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), + [GGML_TYPE_Q4_3] = sizeof(block_q4_3), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3036,12 +3313,13 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = "q4_0", [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", + [GGML_TYPE_Q4_3] = "q4_3", [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3049,12 +3327,13 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = true, [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, + [GGML_TYPE_Q4_3] = true, [GGML_TYPE_Q8_0] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -3316,7 +3595,7 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct (t0->ne[3] == t1->ne[3]); } -static inline bool ggml_is_quantized(enum ggml_type type) { +bool ggml_is_quantized(enum ggml_type type) { return GGML_IS_QUANTIZED[type]; } @@ -3432,7 +3711,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { // initialize cuBLAS #if defined(GGML_USE_CUBLAS) - init_cublas(); + ggml_init_cublas(); #endif is_first_call = false; @@ -5852,7 +6131,6 @@ static void ggml_compute_forward_dup_f32( i10 += ne00 * ir0; while (i10 >= ne0) { i10 -= ne0; - i11++; if (++i11 == ne1) { i11 = 0; if (++i12 == ne2) { @@ -6266,6 +6544,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_3: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7278,18 +7557,16 @@ static void ggml_compute_forward_mul_mat_f32( } #if defined(GGML_USE_CUBLAS) - float *d_X = NULL; - float *d_Y = NULL; - float *d_D = NULL; const float alpha = 1.0f; const float beta = 0.0f; const int x_ne = ne01 * ne10; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; - CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne)); - CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); - CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); + size_t x_size, y_size, d_size; + float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); + float *d_Y = ggml_cuda_pool_malloc(sizeof(float) 
* y_ne, &y_size); + float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); #endif for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -7301,19 +7578,19 @@ static void ggml_compute_forward_mul_mat_f32( #if defined(GGML_USE_CUBLAS) // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream)); - CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); // compute CUBLAS_CHECK( - cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, &alpha, d_X, ne00, d_Y, ne10, &beta, d_D, ne01)); // copy data to host - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #else // zT = y * xT cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, @@ -7325,10 +7602,10 @@ static void ggml_compute_forward_mul_mat_f32( } } #if defined(GGML_USE_CUBLAS) - CUDA_CHECK(cudaStreamSynchronize(cudaStream)); - CUDA_CHECK(cudaFree(d_X)); - CUDA_CHECK(cudaFree(d_Y)); - CUDA_CHECK(cudaFree(d_D)); + CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_Y, y_size); + ggml_cuda_pool_free(d_D, d_size); #endif //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); @@ -7478,18 +7755,16 @@ static void ggml_compute_forward_mul_mat_f16_f32( #if defined(GGML_USE_CUBLAS) ggml_fp16_t * const wdata = params->wdata; - float *d_X = NULL; - float *d_Y = NULL; - float *d_D = NULL; const float alpha = 1.0f; const float beta = 0.0f; const int x_ne = ne01 * ne10; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; - CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne)); - CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); - CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); + size_t x_size, y_size, d_size; + float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); + float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); + float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); #else float * const wdata = params->wdata; #endif @@ -7523,12 +7798,12 @@ static void ggml_compute_forward_mul_mat_f16_f32( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream)); - CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); // compute CUBLAS_CHECK( - cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, &alpha, d_X, CUDA_R_16F, ne00, d_Y, CUDA_R_16F, ne10, @@ -7537,7 +7812,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( CUBLAS_GEMM_DEFAULT)); // copy data to host - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, 
cudaMemcpyDeviceToHost, g_cudaStream)); #else const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); @@ -7555,10 +7830,10 @@ static void ggml_compute_forward_mul_mat_f16_f32( } #if defined(GGML_USE_CUBLAS) - CUDA_CHECK(cudaStreamSynchronize(cudaStream)); - CUDA_CHECK(cudaFree(d_X)); - CUDA_CHECK(cudaFree(d_Y)); - CUDA_CHECK(cudaFree(d_D)); + CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_Y, y_size); + ggml_cuda_pool_free(d_D, d_size); #endif /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ @@ -7726,20 +8001,17 @@ static void ggml_compute_forward_mul_mat_q_f32( } #if defined(GGML_USE_CUBLAS) - float *d_X = NULL; - float *d_Y = NULL; - float *d_D = NULL; - float *d_Q = NULL; const float alpha = 1.0f; const float beta = 0.0f; const int x_ne = ne01 * ne10; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; - CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne)); - CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); - CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); - CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type])); + size_t x_size, y_size, d_size, q_size; + float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); + float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); + float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); + float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size); void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL; if (type == GGML_TYPE_Q4_0) { @@ -7769,9 +8041,9 @@ static void ggml_compute_forward_mul_mat_q_f32( // copy and dequantize on device CUDA_CHECK( cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02, - GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream)); + GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream)); - dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream); + dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream); CUDA_CHECK(cudaGetLastError()); #else { @@ -7787,18 +8059,18 @@ static void ggml_compute_forward_mul_mat_q_f32( #if defined(GGML_USE_CUBLAS) // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); // compute CUBLAS_CHECK( - cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, &alpha, d_X, ne00, d_Y, ne10, &beta, d_D, ne01)); // copy data to host - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #else // zT = y * xT cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, @@ -7811,11 +8083,11 @@ static void ggml_compute_forward_mul_mat_q_f32( } #if defined(GGML_USE_CUBLAS) - CUDA_CHECK(cudaStreamSynchronize(cudaStream)); - CUDA_CHECK(cudaFree(d_X)); - CUDA_CHECK(cudaFree(d_Y)); - CUDA_CHECK(cudaFree(d_D)); - CUDA_CHECK(cudaFree(d_Q)); + CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_Y, y_size); + ggml_cuda_pool_free(d_D, d_size); + ggml_cuda_pool_free(d_Q, q_size); #endif 
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); @@ -7905,6 +8177,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_3: case GGML_TYPE_Q8_0: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); @@ -7922,34 +8195,6 @@ static void ggml_compute_forward_mul_mat( GGML_ASSERT(false); } break; } - -#if 0 - if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) { - static int first = 8; - printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); - printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); - printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - if (first) { - --first; - } else { - for (int k = 0; k < dst->ne[1]; ++k) { - for (int j = 0; j < dst->ne[0]/16; ++j) { - for (int i = 0; i < 16; ++i) { - printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - } - printf("\n"); - } - printf("\n"); - } - printf("\n"); - exit(0); - } - } else { - printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); - printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); - printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - } -#endif } // ggml_compute_forward_scale @@ -8161,6 +8406,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_3: case GGML_TYPE_Q8_0: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); @@ -8391,9 +8637,11 @@ static void ggml_compute_forward_rope_f32( const float theta_scale = powf(10000.0, -2.0f/n_dims); + const bool is_neox = mode & 2; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int p = (mode == 0 ? n_past + i2 : i2); + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = ((mode & 1) == 0 ? 
n_past + i2 : i2); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -8406,14 +8654,25 @@ static void ggml_compute_forward_rope_f32( theta *= theta_scale; - const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + if (!is_neox) { + const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float x0 = src[0]; - const float x1 = src[1]; + const float x0 = src[0]; + const float x1 = src[1]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } else { + const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } } } } @@ -8468,9 +8727,11 @@ static void ggml_compute_forward_rope_f16( const float theta_scale = powf(10000.0, -2.0f/n_dims); + const bool is_neox = mode & 2; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int p = (mode == 0 ? n_past + i2 : i2); + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = ((mode & 1) == 0 ? n_past + i2 : i2); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -8483,14 +8744,25 @@ static void ggml_compute_forward_rope_f16( theta *= theta_scale; - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + if (!is_neox) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[1]); + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[1]); - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } else { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } } } } @@ -11900,6 +12172,62 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_2*sizeof(block_q4_2)); } +size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK4_3 == 0); + const int nb = k / QK4_3; + + for (int j = 0; j < n; j += k) { + block_q4_3 * restrict y = 
(block_q4_3 *)dst + j/QK4_3; + + quantize_row_q4_3_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK4_3; l += 2) { + const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi1 = y[i].qs[l/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK4_3*sizeof(block_q4_3)); +} + +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { + size_t result = 0; + switch (type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(start % QK4_0 == 0); + block_q4_0 * block = (block_q4_0*)dst + start / QK4_0; + result = ggml_quantize_q4_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(start % QK4_1 == 0); + block_q4_1 * block = (block_q4_1*)dst + start / QK4_1; + result = ggml_quantize_q4_1(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_2: + { + GGML_ASSERT(start % QK4_2 == 0); + block_q4_2 * block = (block_q4_2*)dst + start / QK4_2; + result = ggml_quantize_q4_2(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_3: + { + GGML_ASSERT(start % QK4_3 == 0); + block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; + result = ggml_quantize_q4_3(src + start, block, n, n, hist); + } break; + default: + assert(false); + } + return result; +} + //////////////////////////////////////////////////////////////////////////////// int ggml_cpu_has_avx(void) { diff --git a/ggml.h b/ggml.h index 570147fc2..460d4ffe0 100644 --- a/ggml.h +++ b/ggml.h @@ -205,7 +205,8 @@ enum ggml_type { GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, - GGML_TYPE_Q8_0 = 5, + GGML_TYPE_Q4_3 = 5, + GGML_TYPE_Q8_0 = 6, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -360,6 +361,8 @@ const char * ggml_type_name(enum ggml_type type); size_t ggml_element_size(const struct ggml_tensor * tensor); +bool ggml_is_quantized(enum ggml_type type); + struct ggml_context * ggml_init(struct ggml_init_params params); void ggml_free(struct ggml_context * ctx); @@ -627,7 +630,8 @@ struct ggml_tensor * ggml_soft_max( // rotary position embedding // in-place, returns view(a) -// if mode == 1, skip n_past elements +// if mode & 1 == 1, skip n_past elements +// if mode & 2 == 1, GPT-NeoX style // TODO: avoid creating a new tensor every time struct ggml_tensor * ggml_rope( struct ggml_context * ctx, @@ -808,6 +812,9 @@ enum ggml_opt_result ggml_opt( size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); + +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); // // system info diff --git a/llama.cpp b/llama.cpp index 0a764a367..7829c5091 100644 --- a/llama.cpp +++ b/llama.cpp @@ -24,6 +24,9 @@ #include #include #include +#include +#include +#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -479,6 +482,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_3: break; default: { throw format("unrecognized tensor type %u\n", shard.type); @@ -552,6 +556,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: + case GGML_TYPE_Q4_3: break; default: LLAMA_ASSERT(false); } @@ -841,6 +846,7 @@ static const char 
*llama_ftype_name(enum llama_ftype ftype) {
     case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
                                   return "mostly Q4_1, some F16";
     case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2";
+    case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3";
     default:                      return "unknown, may not work";
   }
 }
 
@@ -1571,15 +1577,20 @@ static llama_vocab::id llama_sample_top_p_top_k(
 // quantization
 //
 
-static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype) {
+static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype, int nthread) {
     ggml_type quantized_type;
     switch (ftype) {
         case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
         case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
         case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break;
         default: throw format("invalid output file type %d\n", ftype);
     };
 
+    if (nthread <= 0) {
+        nthread = std::thread::hardware_concurrency();
+    }
+
     std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp.c_str(), /*use_mmap*/ false, /*vocab_only*/ false));
     llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype);
 
@@ -1588,6 +1599,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     size_t total_size_new = 0;
     std::vector<int64_t> hist_all(1 << 4, 0);
 
+    std::vector<std::thread> workers;
+    std::mutex mutex;
+
     size_t idx = 0;
     for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) {
         llama_buffer read_data;
@@ -1606,6 +1620,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // quantize only 2D tensors
         quantize &= (tensor.ne.size() == 2);
 
+        // uncomment this to keep the output layer in FP16
+        //if (tensor.name == "output.weight") {
+        //    quantize = false;
+        //}
+
         enum ggml_type new_type;
         void * new_data;
         size_t new_size;
@@ -1641,21 +1660,37 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             new_data = work.addr;
 
             std::vector<int64_t> hist_cur(1 << 4, 0);
 
-            switch (new_type) {
-                case GGML_TYPE_Q4_0:
-                    {
-                        new_size = ggml_quantize_q4_0(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
-                    } break;
-                case GGML_TYPE_Q4_1:
-                    {
-                        new_size = ggml_quantize_q4_1(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
-                    } break;
-                case GGML_TYPE_Q4_2:
-                    {
-                        new_size = ggml_quantize_q4_2(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
-                    } break;
-                default:
-                    LLAMA_ASSERT(false);
+            int chunk_size = 32 * 512;
+            const int nchunk = (nelements + chunk_size - 1)/chunk_size;
+            const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
+            if (nthread_use < 2) {
+                new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data());
+            } else {
+                size_t counter = 0;
+                new_size = 0;
+                auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements, chunk_size] () {
+                    std::vector<int64_t> local_hist;
+                    size_t local_size = 0;
+                    while (true) {
+                        std::unique_lock<std::mutex> lock(mutex);
+                        size_t first = counter; counter += chunk_size;
+                        if (first >= nelements) {
+                            if (!local_hist.empty()) {
+                                for (int j=0; j<int(local_hist.size()); ++j) hist_cur[j] += local_hist[j];
+                                new_size += local_size;
+                            }
+                            break;
+                        }
+                        lock.unlock();
+                        size_t last = std::min(nelements, first + chunk_size);
+                        if (local_hist.empty()) local_hist.resize(hist_cur.size(), 0);
+                        local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data());
+                    }
+                };
+                if ((int) workers.size() < nthread_use - 1) workers.resize(nthread_use - 1);
+                for (int it = 0; it < nthread_use - 1; ++it) workers[it] = std::thread(compute);
+                compute();
+                for (int it = 0; it < nthread_use - 1; ++it) workers[it].join();
             }
 
             printf("size = %8.2f MB -> %8.2f MB | hist: ", tensor.size/1024.0/1024.0, new_size/1024.0/1024.0);
@@ -1777,9 +1812,10 @@ void llama_free(struct llama_context * ctx) {
 int llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
-        enum llama_ftype ftype) {
+        enum llama_ftype ftype,
+        int nthread) {
     try {
-        llama_model_quantize_internal(fname_inp, fname_out, ftype);
+        llama_model_quantize_internal(fname_inp, fname_out, ftype, nthread);
         return 0;
     } catch (const std::string & err) {
         fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.c_str());
@@ -1965,7 +2001,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
             base_t = dest_t;
         }
 
-        if (base_t->type == GGML_TYPE_Q4_0 || base_t->type == GGML_TYPE_Q4_1 || base_t->type == GGML_TYPE_Q4_2) {
+        if (ggml_is_quantized(base_t->type)) {
             if (!warned) {
                 fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, "
                                 "use a f16 or f32 base model with --lora-base\n", __func__);
@@ -2058,7 +2094,11 @@ void llama_set_kv_cache(
         int n_token_count) {
     // Make sure we have the same kv cache setup
     LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
+    void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+    void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
     memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
+    ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+    ctx->model.kv_self.v->data = v_data;
     ctx->model.kv_self.n = n_token_count;
 }
 
diff --git a/llama.h b/llama.h
index 6a5bcb972..b31ac9aeb 100644
--- a/llama.h
+++ b/llama.h
@@ -73,6 +73,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
         LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
 
@@ -92,10 +93,12 @@ extern "C" {
     // TODO: not great API - very likely to change
    // Returns 0 on success
+    // nthread - how many threads to use.
If <=0, will use std::thread::hardware_concurrency(), else the number given LLAMA_API int llama_model_quantize( const char * fname_inp, const char * fname_out, - enum llama_ftype ftype); + enum llama_ftype ftype, + int nthread); // Apply a LoRA adapter to a loaded model // path_base_model is the path to a higher quality model to use as a base for diff --git a/llama_util.h b/llama_util.h index eba14656a..acb207e65 100755 --- a/llama_util.h +++ b/llama_util.h @@ -21,6 +21,9 @@ #if defined(_POSIX_MAPPED_FILES) #include #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif #endif #endif @@ -303,8 +306,18 @@ struct llama_mlock { if (!mlock(addr, size)) { return true; } else { - fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n" MLOCK_SUGGESTION, - size, this->size, std::strerror(errno)); + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); + + // Check if the resource limit is fine after all + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) + suggest = false; + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) + suggest = false; + + fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); return false; } } diff --git a/pocs/vdot/CMakeLists.txt b/pocs/vdot/CMakeLists.txt index cbc852236..fb89a1cd4 100644 --- a/pocs/vdot/CMakeLists.txt +++ b/pocs/vdot/CMakeLists.txt @@ -2,3 +2,8 @@ set(TARGET vdot) add_executable(${TARGET} vdot.cpp) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) + +set(TARGET q8dot) +add_executable(${TARGET} q8dot.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/pocs/vdot/q8dot.cpp b/pocs/vdot/q8dot.cpp new file mode 100644 index 000000000..5748c8ac2 --- /dev/null +++ b/pocs/vdot/q8dot.cpp @@ -0,0 +1,172 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +constexpr int kVecSize = 1 << 16; + +// Copy-pasted from ggml.c +#define QK4_0 32 +typedef struct { + float d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + float d; // delta + float m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +// Copy-pasted from ggml.c +#define QK8_0 32 +typedef struct { + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); + +static_assert(QK4_1 == QK8_0, "QK4_1 and QK8_0 must be the same"); +static_assert(QK4_0 == QK8_0, "QK4_0 and QK8_0 must be the same"); + +template +void fillQ4blocks(std::vector& blocks, std::mt19937& rndm) { + for (auto& b : blocks) { + b.d = 1; + for (int i=0; i> 28; + uint8_t v2 = rndm() >> 28; + b.qs[i] = v1 | (v2 << 4); + } + } +} + +void fillQ80blocks(std::vector& blocks, std::mt19937& rndm) { + for (auto& b : blocks) { + b.d = 1; + int sum = 0; + for (int i=0; i> 24) - 128; + sum += b.qs[i]; + } + b.s = b.d * sum; + } +} + +float simpleDot(const block_q4_0& x, 
const block_q8_0& y) {
+    int s1 = 0; //, s2 = 0;
+    for (int i=0; i<QK4_0/2; i += 2) {
+        int v1 = x.qs[i] & 0xf;
+        int v2 = x.qs[i] >> 4;
+        int v3 = x.qs[i+1] & 0xf;
+        int v4 = x.qs[i+1] >> 4;
+        int j = 2*i;
+        s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
+        //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
+    }
+    return y.d * x.d * s1 - 8 * x.d * y.s;
+    //return y.d * x.d * (s1 - 8 * s2);
+}
+
+float simpleDot(const block_q4_1& x, const block_q8_0& y) {
+    int s1 = 0; //, s2 = 0;
+    for (int i=0; i<QK4_1/2; i += 2) {
+        int v1 = x.qs[i] & 0xf;
+        int v2 = x.qs[i] >> 4;
+        int v3 = x.qs[i+1] & 0xf;
+        int v4 = x.qs[i+1] >> 4;
+        int j = 2*i;
+        s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
+        //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
+    }
+    return y.d * x.d * s1 + y.s * x.m;
+    //return y.d * (x.d * s1 + x.m * s2);
+}
+
+struct Stat {
+    double sum = 0, sumt = 0, sumt2 = 0, maxt = 0;
+    int nloop = 0;
+    void addResult(double s, double t) {
+        sum += s;
+        sumt += t; sumt2 += t*t; maxt = std::max(maxt, t);
+        ++nloop;
+    }
+    void reportResult(const char* title) const {
+        if (nloop < 1) {
+            printf("%s(%s): no result\n",__func__,title);
+            return;
+        }
+        printf("============ %s\n",title);
+        printf(" = %g\n",sum/nloop);
+        auto t = sumt/nloop, dt = sumt2/nloop - t*t;
+        if (dt > 0) dt = sqrt(dt);
+        printf("