From 0afc2f91db815bbb1f29d2d29ed8814ab5a1c09b Mon Sep 17 00:00:00 2001 From: John Doe Date: Tue, 2 May 2023 08:29:13 -0400 Subject: [PATCH] check for thread support in quantize-stats --- examples/quantize-stats/quantize-stats.cpp | 29 ++++++++++++++++------ 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 9a2aa7c64..f5898edd5 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -16,8 +16,10 @@ #include #include #include +#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS #include #include +#endif struct quantize_stats_params { std::string model = "models/7B/ggml-model-f16.bin"; @@ -194,7 +196,9 @@ void test_roundtrip_on_layer( if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements); if (output_scratch.size() < nelements) output_scratch.resize(nelements); +#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS if (max_thread < 1) max_thread = std::thread::hardware_concurrency(); +#endif int chunk_size = 32*512; int num_chunks = (nelements + chunk_size - 1)/chunk_size; @@ -203,29 +207,38 @@ void test_roundtrip_on_layer( output_scratch.data(), print_layer_stats ? layer_error : total_error); } else { auto & stats = print_layer_stats ? layer_error : total_error; +#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS std::mutex mutex; +#endif uint64_t counter = 0; - auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr, - &quantized_scratch, &output_scratch, chunk_size] () { + auto compute = [&, nelements, layer, use_reference, input_scratch_ptr, chunk_size] () { error_stats local_stats {}; while (true) { - std::unique_lock lock(mutex); - uint64_t offset = counter; counter += chunk_size; - if (offset >= nelements) { - combine_error_stats(stats, local_stats); - break; + uint64_t offset; + { +#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS + std::unique_lock lock(mutex); +#endif + offset = counter; counter += chunk_size; + if (offset >= nelements) { + combine_error_stats(stats, local_stats); + break; + } } - lock.unlock(); uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset; test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset, quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats); } }; +#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS int nthread = std::min(num_chunks, max_thread); std::vector workers(nthread-1); for (auto& w : workers) w = std::thread(compute); compute(); for (auto& w : workers) w.join(); +#else + compute(); +#endif } if (print_layer_stats) {