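Check for thread support in quantize-stats: guard <thread>/<mutex> usage behind a preprocessor check and fall back to single-threaded execution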

This commit is contained in:
John Doe 2023-05-02 08:29:13 -04:00
parent 14fa3d108b
commit 0afc2f91db

View file

@ -16,8 +16,10 @@
#include <string>
#include <unordered_map>
#include <vector>
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
#include <thread>
#include <mutex>
#endif
struct quantize_stats_params {
std::string model = "models/7B/ggml-model-f16.bin";
@ -194,7 +196,9 @@ void test_roundtrip_on_layer(
if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements);
if (output_scratch.size() < nelements) output_scratch.resize(nelements);
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
if (max_thread < 1) max_thread = std::thread::hardware_concurrency();
#endif
int chunk_size = 32*512;
int num_chunks = (nelements + chunk_size - 1)/chunk_size;
@ -203,29 +207,38 @@ void test_roundtrip_on_layer(
output_scratch.data(), print_layer_stats ? layer_error : total_error);
} else {
auto & stats = print_layer_stats ? layer_error : total_error;
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
std::mutex mutex;
#endif
uint64_t counter = 0;
auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr,
&quantized_scratch, &output_scratch, chunk_size] () {
auto compute = [&, nelements, layer, use_reference, input_scratch_ptr, chunk_size] () {
error_stats local_stats {};
while (true) {
std::unique_lock<std::mutex> lock(mutex);
uint64_t offset = counter; counter += chunk_size;
if (offset >= nelements) {
combine_error_stats(stats, local_stats);
break;
uint64_t offset;
{
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
std::unique_lock<std::mutex> lock(mutex);
#endif
offset = counter; counter += chunk_size;
if (offset >= nelements) {
combine_error_stats(stats, local_stats);
break;
}
}
lock.unlock();
uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset,
quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats);
}
};
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
int nthread = std::min(num_chunks, max_thread);
std::vector<std::thread> workers(nthread-1);
for (auto& w : workers) w = std::thread(compute);
compute();
for (auto& w : workers) w.join();
#else
compute();
#endif
}
if (print_layer_stats) {