initial import of hpx support
This commit is contained in:
parent
799fc22689
commit
776f5e29cd
3 changed files with 179 additions and 1 deletions
|
@ -95,6 +95,7 @@ option(LLAMA_CLBLAST "llama: use CLBlast"
|
||||||
option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
|
option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
|
||||||
option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
|
option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
|
||||||
option(LLAMA_MPI "llama: use MPI" OFF)
|
option(LLAMA_MPI "llama: use MPI" OFF)
|
||||||
|
option(LLAMA_HPX "llama: use HPX" OFF)
|
||||||
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
|
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
|
||||||
|
|
||||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||||
|
@ -320,6 +321,10 @@ if (LLAMA_CUBLAS)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(LLAMA_MPI AND LLAMA_HPX)
|
||||||
|
message(FATAL "MPI and HPX are not currently compatible together")
|
||||||
|
endif()
|
||||||
|
|
||||||
if (LLAMA_MPI)
|
if (LLAMA_MPI)
|
||||||
cmake_minimum_required(VERSION 3.10)
|
cmake_minimum_required(VERSION 3.10)
|
||||||
find_package(MPI)
|
find_package(MPI)
|
||||||
|
@ -344,6 +349,17 @@ if (LLAMA_MPI)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (LLAMA_HPX)
|
||||||
|
cmake_minimum_required(VERSION 3.10)
|
||||||
|
find_package (HPX)
|
||||||
|
if (HPX_FOUND)
|
||||||
|
add_compile_definitions(GGML_USE_HPX)
|
||||||
|
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${HPX_CXXFLAGS})
|
||||||
|
else()
|
||||||
|
message(FATAL "HPX not found")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
if (LLAMA_CLBLAST)
|
if (LLAMA_CLBLAST)
|
||||||
find_package(CLBlast)
|
find_package(CLBlast)
|
||||||
if (CLBlast_FOUND)
|
if (CLBlast_FOUND)
|
||||||
|
@ -727,7 +743,11 @@ add_library(ggml OBJECT
|
||||||
|
|
||||||
target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
|
target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
|
||||||
target_compile_features(ggml PUBLIC c_std_11) # don't bump
|
target_compile_features(ggml PUBLIC c_std_11) # don't bump
|
||||||
target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
|
if(LLAMA_HPX AND HPX_FOUND)
|
||||||
|
target_link_libraries(ggml PUBLIC HPX::hpx ${LLAMA_EXTRA_LIBS})
|
||||||
|
else()
|
||||||
|
target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
|
||||||
|
endif()
|
||||||
if (GGML_USE_CPU_HBM)
|
if (GGML_USE_CPU_HBM)
|
||||||
target_link_libraries(ggml PUBLIC memkind)
|
target_link_libraries(ggml PUBLIC memkind)
|
||||||
endif()
|
endif()
|
||||||
|
@ -749,6 +769,12 @@ add_library(llama
|
||||||
|
|
||||||
target_include_directories(llama PUBLIC .)
|
target_include_directories(llama PUBLIC .)
|
||||||
target_compile_features(llama PUBLIC cxx_std_11) # don't bump
|
target_compile_features(llama PUBLIC cxx_std_11) # don't bump
|
||||||
|
|
||||||
|
if(LLAMA_HPX AND HPX_FOUND)
|
||||||
|
target_link_libraries(llama PUBLIC HPX::hpx ${LLAMA_EXTRA_LIBS})
|
||||||
|
target_compile_options (llama PRIVATE ${HPX_CXXFLAGS})
|
||||||
|
endif()
|
||||||
|
|
||||||
target_link_libraries(llama PRIVATE
|
target_link_libraries(llama PRIVATE
|
||||||
ggml
|
ggml
|
||||||
${LLAMA_EXTRA_LIBS}
|
${LLAMA_EXTRA_LIBS}
|
||||||
|
|
44
Makefile
44
Makefile
|
@ -103,7 +103,11 @@ endif
|
||||||
# keep standard at C11 and C++11
|
# keep standard at C11 and C++11
|
||||||
MK_CPPFLAGS = -I. -Icommon
|
MK_CPPFLAGS = -I. -Icommon
|
||||||
MK_CFLAGS = -std=c11 -fPIC
|
MK_CFLAGS = -std=c11 -fPIC
|
||||||
|
ifdef LLAMA_HPX
|
||||||
|
MK_CXXFLAGS = -std=c++17 -fPIC
|
||||||
|
else
|
||||||
MK_CXXFLAGS = -std=c++11 -fPIC
|
MK_CXXFLAGS = -std=c++11 -fPIC
|
||||||
|
endif
|
||||||
|
|
||||||
# -Ofast tends to produce faster code, but may not be available for some compilers.
|
# -Ofast tends to produce faster code, but may not be available for some compilers.
|
||||||
ifdef LLAMA_FAST
|
ifdef LLAMA_FAST
|
||||||
|
@ -345,6 +349,46 @@ ifdef LLAMA_MPI
|
||||||
OBJS += ggml-mpi.o
|
OBJS += ggml-mpi.o
|
||||||
endif # LLAMA_MPI
|
endif # LLAMA_MPI
|
||||||
|
|
||||||
|
ifdef LLAMA_HPX
|
||||||
|
ifndef HWLOC_FOUND
|
||||||
|
HWLOC_PKG:=hwloc
|
||||||
|
HWLOC_REQPKG:=$(shell pkg-config --exists $(HWLOC_PKG) && echo '$(HWLOC_PKG)')
|
||||||
|
ifneq ($(HWLOC_REQPKG),)
|
||||||
|
HWLOC_FOUND:=1
|
||||||
|
HWLOC_CXXFLAGS:=$(shell pkg-config --cflags $(HWLOC_PKG))
|
||||||
|
HWLOC_LDFLAGS:=$(shell pkg-config --libs $(HWLOC_PKG))
|
||||||
|
warn := $(warning hwloc found)
|
||||||
|
else
|
||||||
|
$(warning 'hwloc' not found)
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifndef HWLOC_FOUND
|
||||||
|
$(error hwloc not found)
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifndef HPX_FOUND
|
||||||
|
HPX_PKG:=hpx_component
|
||||||
|
HPX_REQPKG:=$(shell pkg-config --exists $(HPX_PKG) && echo '$(HPX_PKG)')
|
||||||
|
ifneq ($(HPX_REQPKG),)
|
||||||
|
HPX_FOUND:=1
|
||||||
|
HPX_CXXFLAGS:=$(shell pkg-config --cflags hpx_component)
|
||||||
|
HPX_LDFLAGS:=$(shell pkg-config --libs hpx_component)
|
||||||
|
warn := $(warning HPX found)
|
||||||
|
else
|
||||||
|
$(warning 'HPX' not found)
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifndef HPX_FOUND
|
||||||
|
$(error HPX not found)
|
||||||
|
endif
|
||||||
|
|
||||||
|
MK_CPPFLAGS += -DGGML_USE_HPX $(HWLOC_CXXFLAGS) $(HPX_CXXFLAGS)
|
||||||
|
MK_CXXFLAGS += -Wno-cast-qual $(HWLOC_CXXFLAGS) $(HPX_CXXFLAGS)
|
||||||
|
MK_LDFLAGS += -Wno-cast-qual $(HWLOC_LDFLAGS) $(HPX_LDFLAGS)
|
||||||
|
endif # LLAMA_HPX
|
||||||
|
|
||||||
ifdef LLAMA_OPENBLAS
|
ifdef LLAMA_OPENBLAS
|
||||||
MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas)
|
MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas)
|
||||||
MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas)
|
MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas)
|
||||||
|
|
108
llama.cpp
108
llama.cpp
|
@ -19,6 +19,12 @@
|
||||||
#ifdef GGML_USE_MPI
|
#ifdef GGML_USE_MPI
|
||||||
# include "ggml-mpi.h"
|
# include "ggml-mpi.h"
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef GGML_USE_HPX
|
||||||
|
# include <cstdlib>
|
||||||
|
# include <hpx/hpx_start.hpp>
|
||||||
|
# include <hpx/runtime_local/run_as_hpx_thread.hpp>
|
||||||
|
# include <hpx/execution.hpp>
|
||||||
|
#endif
|
||||||
#ifndef QK_K
|
#ifndef QK_K
|
||||||
# ifdef GGML_QKK_64
|
# ifdef GGML_QKK_64
|
||||||
# define QK_K 64
|
# define QK_K 64
|
||||||
|
@ -8419,6 +8425,81 @@ struct quantize_state_internal {
|
||||||
{}
|
{}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if defined(GGML_USE_HPX)
|
||||||
|
|
||||||
|
static void llama_convert_tensor_internal(
|
||||||
|
struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<hpx::future<void>> & futures,
|
||||||
|
const size_t nelements, const int nthread
|
||||||
|
) {
|
||||||
|
if (output.size() < nelements) {
|
||||||
|
output.resize(nelements);
|
||||||
|
}
|
||||||
|
float * f32_output = (float *) output.data();
|
||||||
|
|
||||||
|
ggml_type_traits_t qtype;
|
||||||
|
if (ggml_is_quantized(tensor->type)) {
|
||||||
|
qtype = ggml_internal_get_type_traits(tensor->type);
|
||||||
|
if (qtype.to_float == NULL) {
|
||||||
|
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
|
||||||
|
}
|
||||||
|
} else if (tensor->type != GGML_TYPE_F16) {
|
||||||
|
throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nthread < 2) {
|
||||||
|
if (tensor->type == GGML_TYPE_F16) {
|
||||||
|
ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
|
||||||
|
} else if (ggml_is_quantized(tensor->type)) {
|
||||||
|
qtype.to_float(tensor->data, f32_output, nelements);
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false); // unreachable
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
|
||||||
|
size_t block_size_bytes = ggml_type_size(tensor->type);
|
||||||
|
|
||||||
|
GGML_ASSERT(nelements % block_size == 0);
|
||||||
|
size_t nblocks = nelements / block_size;
|
||||||
|
size_t blocks_per_thread = nblocks / nthread;
|
||||||
|
size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
|
||||||
|
|
||||||
|
size_t in_buff_offs = 0;
|
||||||
|
size_t out_buff_offs = 0;
|
||||||
|
|
||||||
|
hpx::future<void> fut =
|
||||||
|
hpx::run_as_hpx_thread([&futures, nthread, qtype, block_size, block_size_bytes, blocks_per_thread, spare_blocks, &tensor, &in_buff_offs, &f32_output, &out_buff_offs]() -> hpx::future<void>
|
||||||
|
{
|
||||||
|
for (int tnum = 0; tnum < nthread; tnum++) {
|
||||||
|
size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
|
||||||
|
size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
|
||||||
|
size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
|
||||||
|
|
||||||
|
auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
|
||||||
|
if (typ == GGML_TYPE_F16) {
|
||||||
|
ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
|
||||||
|
} else {
|
||||||
|
qtype.to_float(inbuf, outbuf, nels);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
futures.push_back(hpx::async(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems));
|
||||||
|
|
||||||
|
in_buff_offs += thr_block_bytes;
|
||||||
|
out_buff_offs += thr_elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
hpx::wait_all(futures);
|
||||||
|
return hpx::make_ready_future<void>();
|
||||||
|
});
|
||||||
|
|
||||||
|
fut.wait();
|
||||||
|
futures.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
static void llama_convert_tensor_internal(
|
static void llama_convert_tensor_internal(
|
||||||
struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
||||||
const size_t nelements, const int nthread
|
const size_t nelements, const int nthread
|
||||||
|
@ -8480,6 +8561,8 @@ static void llama_convert_tensor_internal(
|
||||||
workers.clear();
|
workers.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
|
static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
|
||||||
const std::string name = ggml_get_name(tensor);
|
const std::string name = ggml_get_name(tensor);
|
||||||
|
|
||||||
|
@ -8687,8 +8770,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
size_t total_size_new = 0;
|
size_t total_size_new = 0;
|
||||||
std::vector<int64_t> hist_all(1 << 4, 0);
|
std::vector<int64_t> hist_all(1 << 4, 0);
|
||||||
|
|
||||||
|
#if defined(GGML_USE_HPX)
|
||||||
|
{
|
||||||
|
std::string thread_arg = "--hpx:threads=" + std::to_string(nthread);
|
||||||
|
hpx::init_params params;
|
||||||
|
params.cfg = { thread_arg };
|
||||||
|
hpx::start(nullptr, 0, nullptr, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<hpx::future<void>> futures;
|
||||||
|
futures.reserve(nthread);
|
||||||
|
#else
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
workers.reserve(nthread);
|
workers.reserve(nthread);
|
||||||
|
#endif
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
|
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
|
@ -8772,7 +8867,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
} else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
|
} else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
|
||||||
throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
|
throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
|
||||||
} else {
|
} else {
|
||||||
|
#if defined(GGML_USE_HPX)
|
||||||
|
llama_convert_tensor_internal(tensor, f32_conv_buf, futures, nelements, nthread);
|
||||||
|
#else
|
||||||
llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread);
|
llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread);
|
||||||
|
#endif
|
||||||
f32_data = (float *) f32_conv_buf.data();
|
f32_data = (float *) f32_conv_buf.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8813,12 +8912,21 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data());
|
local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#if defined(GGML_USE_HPX)
|
||||||
|
for (int it = 0; it < nthread_use - 1; ++it) {
|
||||||
|
futures.push_back(hpx::async(compute));
|
||||||
|
}
|
||||||
|
compute();
|
||||||
|
hpx::wait_all(futures);
|
||||||
|
futures.clear();
|
||||||
|
#else
|
||||||
for (int it = 0; it < nthread_use - 1; ++it) {
|
for (int it = 0; it < nthread_use - 1; ++it) {
|
||||||
workers.emplace_back(compute);
|
workers.emplace_back(compute);
|
||||||
}
|
}
|
||||||
compute();
|
compute();
|
||||||
for (auto & w : workers) { w.join(); }
|
for (auto & w : workers) { w.join(); }
|
||||||
workers.clear();
|
workers.clear();
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
|
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue