check for thread support in quantize-stats
This commit is contained in:
parent
14fa3d108b
commit
0afc2f91db
1 changed files with 21 additions and 8 deletions
|
@ -16,8 +16,10 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#endif
|
||||
|
||||
struct quantize_stats_params {
|
||||
std::string model = "models/7B/ggml-model-f16.bin";
|
||||
|
@ -194,7 +196,9 @@ void test_roundtrip_on_layer(
|
|||
if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements);
|
||||
if (output_scratch.size() < nelements) output_scratch.resize(nelements);
|
||||
|
||||
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||
if (max_thread < 1) max_thread = std::thread::hardware_concurrency();
|
||||
#endif
|
||||
int chunk_size = 32*512;
|
||||
int num_chunks = (nelements + chunk_size - 1)/chunk_size;
|
||||
|
||||
|
@ -203,29 +207,38 @@ void test_roundtrip_on_layer(
|
|||
output_scratch.data(), print_layer_stats ? layer_error : total_error);
|
||||
} else {
|
||||
auto & stats = print_layer_stats ? layer_error : total_error;
|
||||
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||
std::mutex mutex;
|
||||
#endif
|
||||
uint64_t counter = 0;
|
||||
auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr,
|
||||
&quantized_scratch, &output_scratch, chunk_size] () {
|
||||
auto compute = [&, nelements, layer, use_reference, input_scratch_ptr, chunk_size] () {
|
||||
error_stats local_stats {};
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
uint64_t offset = counter; counter += chunk_size;
|
||||
if (offset >= nelements) {
|
||||
combine_error_stats(stats, local_stats);
|
||||
break;
|
||||
uint64_t offset;
|
||||
{
|
||||
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
#endif
|
||||
offset = counter; counter += chunk_size;
|
||||
if (offset >= nelements) {
|
||||
combine_error_stats(stats, local_stats);
|
||||
break;
|
||||
}
|
||||
}
|
||||
lock.unlock();
|
||||
uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
|
||||
test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset,
|
||||
quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats);
|
||||
}
|
||||
};
|
||||
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||
int nthread = std::min(num_chunks, max_thread);
|
||||
std::vector<std::thread> workers(nthread-1);
|
||||
for (auto& w : workers) w = std::thread(compute);
|
||||
compute();
|
||||
for (auto& w : workers) w.join();
|
||||
#else
|
||||
compute();
|
||||
#endif
|
||||
}
|
||||
|
||||
if (print_layer_stats) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue