Check for thread support in quantize-stats
This commit is contained in:
parent
14fa3d108b
commit
0afc2f91db
1 changed file with 21 additions and 8 deletions
|
@ -16,8 +16,10 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#endif
|
||||||
|
|
||||||
struct quantize_stats_params {
|
struct quantize_stats_params {
|
||||||
std::string model = "models/7B/ggml-model-f16.bin";
|
std::string model = "models/7B/ggml-model-f16.bin";
|
||||||
|
@ -194,7 +196,9 @@ void test_roundtrip_on_layer(
|
||||||
if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements);
|
if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements);
|
||||||
if (output_scratch.size() < nelements) output_scratch.resize(nelements);
|
if (output_scratch.size() < nelements) output_scratch.resize(nelements);
|
||||||
|
|
||||||
|
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||||
if (max_thread < 1) max_thread = std::thread::hardware_concurrency();
|
if (max_thread < 1) max_thread = std::thread::hardware_concurrency();
|
||||||
|
#endif
|
||||||
int chunk_size = 32*512;
|
int chunk_size = 32*512;
|
||||||
int num_chunks = (nelements + chunk_size - 1)/chunk_size;
|
int num_chunks = (nelements + chunk_size - 1)/chunk_size;
|
||||||
|
|
||||||
|
@ -203,29 +207,38 @@ void test_roundtrip_on_layer(
|
||||||
output_scratch.data(), print_layer_stats ? layer_error : total_error);
|
output_scratch.data(), print_layer_stats ? layer_error : total_error);
|
||||||
} else {
|
} else {
|
||||||
auto & stats = print_layer_stats ? layer_error : total_error;
|
auto & stats = print_layer_stats ? layer_error : total_error;
|
||||||
|
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
|
#endif
|
||||||
uint64_t counter = 0;
|
uint64_t counter = 0;
|
||||||
auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr,
|
auto compute = [&, nelements, layer, use_reference, input_scratch_ptr, chunk_size] () {
|
||||||
&quantized_scratch, &output_scratch, chunk_size] () {
|
|
||||||
error_stats local_stats {};
|
error_stats local_stats {};
|
||||||
while (true) {
|
while (true) {
|
||||||
std::unique_lock<std::mutex> lock(mutex);
|
uint64_t offset;
|
||||||
uint64_t offset = counter; counter += chunk_size;
|
{
|
||||||
if (offset >= nelements) {
|
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||||
combine_error_stats(stats, local_stats);
|
std::unique_lock<std::mutex> lock(mutex);
|
||||||
break;
|
#endif
|
||||||
|
offset = counter; counter += chunk_size;
|
||||||
|
if (offset >= nelements) {
|
||||||
|
combine_error_stats(stats, local_stats);
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
lock.unlock();
|
|
||||||
uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
|
uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
|
||||||
test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset,
|
test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset,
|
||||||
quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats);
|
quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#if __STDCPP_THREADS__ || _GLIBCXX_HAS_GTHREADS
|
||||||
int nthread = std::min(num_chunks, max_thread);
|
int nthread = std::min(num_chunks, max_thread);
|
||||||
std::vector<std::thread> workers(nthread-1);
|
std::vector<std::thread> workers(nthread-1);
|
||||||
for (auto& w : workers) w = std::thread(compute);
|
for (auto& w : workers) w = std::thread(compute);
|
||||||
compute();
|
compute();
|
||||||
for (auto& w : workers) w.join();
|
for (auto& w : workers) w.join();
|
||||||
|
#else
|
||||||
|
compute();
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
if (print_layer_stats) {
|
if (print_layer_stats) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue