diff --git a/CMakeLists.txt b/CMakeLists.txt index 7bd640966..cd9be0c11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,6 +101,7 @@ option(LLAMA_METAL "llama: use Metal" option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) option(LLAMA_METAL_SHADER_DEBUG "llama: compile Metal with -fno-fast-math" OFF) option(LLAMA_MPI "llama: use MPI" OFF) +option(LLAMA_OPENSHMEM "llama: use OpenSHMEM" OFF) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) @@ -385,6 +386,29 @@ if (LLAMA_MPI) endif() endif() +if (LLAMA_OPENSHMEM) + cmake_minimum_required(VERSION 3.10) + include(cmake/FindOpenSHMEM.cmake) + + setup_openshmem() + + if (OPENSHMEM_FOUND) + message(STATUS "OpenSHMEM found") + set(GGML_HEADERS_OPENSHMEM ggml-oshmem.h) + set(GGML_SOURCES_OPENSHMEM ggml-oshmem.c ggml-oshmem.h) + add_compile_definitions(GGML_USE_OPENSHMEM) + + if (NOT MSVC) + add_compile_options(-Wno-cast-qual) + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${OPENSHMEM_LDFLAGS}) + string(REPLACE "-I" "" OPENSHMEM_CFLAGS ${OPENSHMEM_CFLAGS}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${OPENSHMEM_CFLAGS}) + else() + message(WARNING "OpenSHMEM not found") + endif() +endif() + if (LLAMA_CLBLAST) find_package(CLBlast) if (CLBlast_FOUND) @@ -770,6 +794,7 @@ add_library(ggml OBJECT ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} + ${GGML_SOURCES_OPENSHMEM} ${GGML_HEADERS_OPENSHMEM} ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} ) diff --git a/Makefile b/Makefile index 995b89f7a..25c11b0f4 100644 --- a/Makefile +++ b/Makefile @@ -350,6 +350,56 @@ ifdef LLAMA_MPI OBJS += ggml-mpi.o endif # LLAMA_MPI +ifdef LLAMA_OPENSHMEM + ifndef OPENSHMEM_FOUND + OSHMEM_PKG:=sandia-openshmem + OSHMEM_REQPKG:=$(shell pkg-config --exists $(OSHMEM_PKG) && echo '$(OSHMEM_PKG)') + ifneq ($(OSHMEM_REQPKG),) + OPENSHMEM_FOUND:=1 + OPENSHMEM_CFLAGS:=$(shell pkg-config --cflags sandia-openshmem) + OPENSHMEM_LDFLAGS:=$(shell pkg-config --libs sandia-openshmem) + warn := $(warning OpenSHMEM found) + else + $(warning '$(OSHMEM_PKG)' not found) + endif + endif + + ifndef OPENSHMEM_FOUND + OSHMEM_PKG:=osss-ucx + OSHMEM_REQPKG:=$(shell pkg-config --exists $(OSHMEM_PKG) && echo '$(OSHMEM_PKG)') + ifneq ($(OSHMEM_REQPKG),) + OPENSHMEM_FOUND:=1 + OPENSHMEM_CFLAGS:=$(shell pkg-config --cflags osss-ucx) + OPENSHMEM_LDFLAGS:=$(shell pkg-config --libs osss-ucx) + warn := $(warning OpenSHMEM found) + else + $(warning '$(OSHMEM_PKG)' not found) + endif + endif + + ifndef OPENSHMEM_FOUND + OSHMEM_PKG:=oshmem + OSHMEM_REQPKG:=$(shell pkg-config --exists $(OSHMEM_PKG) && echo '$(OSHMEM_PKG)') + ifneq ($(OSHMEM_REQPKG),) + OPENSHMEM_FOUND:=1 + OPENSHMEM_CFLAGS:=$(shell oshmem_info --path libdir) + OPENSHMEM_LDFLAGS:=$(shell oshmem_info --path incdir) + warn := $(warning OpenSHMEM found) + else + $(warning '$(OSHMEM_PKG)' not found) + endif + endif + + ifndef OPENSHMEM_FOUND + $(error OpenSHMEM not found) + endif + + MK_CPPFLAGS += -DGGML_USE_OPENSHMEM $(OPENSHMEM_CFLAGS) + MK_CFLAGS += -Wno-cast-qual $(OPENSHMEM_CFLAGS) + MK_LDFLAGS += -Wno-cast-qual $(OPENSHMEM_LDFLAGS) + OBJS += ggml-oshmem.o +endif # LLAMA_OPENSHMEM + ifdef LLAMA_OPENBLAS MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas) MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas) diff --git a/README.md b/README.md index 866aa87b4..553ad91d2 100644 --- a/README.md +++ b/README.md @@ -342,6 +342,51 @@ Finally, you're ready to run a computation using `mpirun`: mpirun -hostfile hostfile -n 3 ./main -m ./models/7B/ggml-model-q4_0.gguf -n 128 ``` +### OpenSHMEM Build + +OpenSHMEM lets you distribute a computation over a cluster of machines using a Partitioned Global Address Space (PGAS). OpenSHMEM's cluster abstraction is the Parallel-Random-Access-Machine (PRAM). OpenSHMEM's status as a PRAM abstraction means applications are written using the Single-Program-Many-Data (SPMD) style. OpenSHMEM is a shared memory machine abstraction for a cluster. The shared-memory machine abstraction means distributed communications operate like memory copies (memcpy). The receiver does not get a "notification" that communication events have occurred. Senders and recievers can "put" and "get" to remote memory at will. OpenSHMEM is a single-sided communication model that tends to yield improved performance for certain applications. The caveat to that statement is the underlying hardware and software layers. OpenSHMEM operates best when the communication protocol is "fire and forget" (similar to UDP). OpenSHMEM operates best on systems with remote-direct-memory-access (RDMA) enabled network-interface-cards (NICs). OpenSHMEM can work over a commodity ethernet cluster. OpenSHMEM can work on a single machine using a shared memory backend. llama.cpp's OpenSHMEM backend is designed for cluster environments. LLM inference is an inherently serial process. Using OpenSHMEM will not yield any significant [strong scaling](https://hpc-wiki.info/hpc/Scaling#Strong_or_Weak_Scaling) effects. OpenSHMEM it will let you run larger models (over a cluster) than would otherwise fit into the memory (RAM) of a single machine. + +First you will need the OpenSHMEM libraries installed on your system. There are 3 options: [OpenMPI's OpenSHMEM](https://www.open-mpi.org), [OSSS-OpenSHMEM](https://github.com/openshmem-org/osss-ucx) and [Sandia-OpenSHMEM](https://github.com/Sandia-OpenSHMEM/SOS). OSSS-OpenSHMEM has a dependency on the [UCX](https://github.com/openucx/ucx) communication library. Sandia-OpenSHMEM can run over udp, [UCX](https://github.com/openucx/ucx), or [libfabric](https://github.com/ofiwg/libfabric). OpenMPI's OpenSHMEM can be installed with a package manager (apt, homebrew, etc). UCX, OSSS-OpenSHMEM, and Sandia-OpenSHMEM can all be installed from source. + +Next you will need to build the project with `LLAMA_OPENSHMEM` set to true on all machines; if you're building with `make`, you will also need to specify an OpenSHMEM-capable compiler (when building with CMake, this is configured automatically): + +- Using `make`: + + ```bash + make CC=oshcc CXX=oshc++ LLAMA_OPENSHMEM=1 + ``` + +- Using `CMake`: + + ```bash + cmake -S . -B build -DCMAKE_C_COMPILER=oshcc -DCMAKE_CXX_COMPILER=oshc++ -DLLAMA_OPENSHMEM=ON + ``` + +It's strongly encouraged that users exercise this backend over a cluster that is configured to operate like a parallel machine. This means users should consider installing and configuring a distributed file system (NFS). Users are also encouraged to install a bulk-synchronous scheduler (ie: (Slurm)[https://slurm.schedmd.com]). Typical parallel machine configurations usually have 2 networks, a network for slurm/NFS and a seperate network for compute. This may not be practical for most users. After compiling llama.cpp w/OpenSHMEM, users will just need to copy the programs and weights onto the distributed file system. In order to run llama.cpp w/OpenSHMEM a user will need to run the program from the distributed file system using a bulk-synchronous scheduler. The following example assumes a slurm cluster is setup and configured. The example asserts an NFS installation is setup, configured, and mounted on each machine with the following path: `/nfs_path`. + +``` +srun -n 2 /nfs_path/main -m /nfs_path/models/7B/ggml-model-q4_0.gguf -n 128 +``` + +If you do not have access to a cluster with a bulk-synchronous scheduler or a distributed file system, the following instructions will help you stage an installation and run the application. Build the programs, download/convert the weights on all of the machines in your cluster. The paths to the weights and programs should be identical on all machines. + +Next, ensure password-less SSH access to each machine from the primary host, and create a `hostfile` with a list of the hostnames and their relative "weights" (slots). If you want to use localhost for computation, use its local subnet IP address rather than the loopback address or "localhost". + +Here is an example hostfile: + +``` +192.168.0.1:1 +malvolio.local:1 +``` + +The above will distribute the computation across 1 processes on the first host and 1 process on the second host. Each process will use roughly an equal amount of RAM. Try to keep these numbers small, as inter-process (intra-host) communication is expensive. It is a requirement of OpenSHMEM that the distributed job be performed over a number of machines that is equal to a power of 2. + +Finally, you're ready to run a computation using `mpirun`: + +```bash +oshrun -hostfile hostfile -n 2 ./main -m ./models/7B/ggml-model-q4_0.gguf -n 128 +``` + ### BLAS Build Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). Support with CPU-only BLAS implementations doesn't affect the normal generation performance. We may see generation performance improvements with GPU-involved BLAS implementations, e.g. cuBLAS, hipBLAS and CLBlast. There are currently several different BLAS implementations available for build and use: diff --git a/cmake/FindOpenSHMEM.cmake b/cmake/FindOpenSHMEM.cmake new file mode 100644 index 000000000..550d7639f --- /dev/null +++ b/cmake/FindOpenSHMEM.cmake @@ -0,0 +1,917 @@ +# Copyright (c) 2019-2023 Ste||ar Group +# +# SPDX-License-Identifier: BSL-1.0 +# Distributed under the Boost Software License, Version 1.0. (See accompanying +# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +# +macro(setup_openshmem) + + if(NOT TARGET PkgConfig::OPENSHMEM) + + set(OPENSHMEM_PC "") + + find_package(MPI) + if (LLAMA_MPI AND MPI_C_FOUND) + set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:${MPI_LIBDIR}/pkgconfig") + + set(OPENSHMEM_PC "oshmem") + pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC}) + + if(NOT OPENSHMEM_FOUND) + find_program(OSHMEM_INFO NAMES oshmem_info ompi_info REQUIRED) + + if(NOT OSHMEM_INFO) + message( + FATAL_ERROR + "oshmem_info and/or ompi_info not found! pkg-config cannot find OpenMPI's `${OPENSHMEM_PC}.pc`" + ) + endif() + + set(OSHMEM_INFO_OUTPUT + "${CMAKE_CURRENT_SOURCE_DIR}/oshmem_info_stdout.log" + ) + set(OSHMEM_INFO_ERROR + "${CMAKE_CURRENT_SOURCE_DIR}/oshmem_info_error.log" + ) + + execute_process( + COMMAND bash -c "${OSHMEM_INFO} --path libdir" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE OSHMEM_INFO_STATUS + OUTPUT_FILE ${OSHMEM_INFO_OUTPUT} + ERROR_FILE ${OSHMEM_INFO_ERROR} + ) + + if(OSHMEM_INFO_STATUS) + message( + FATAL_ERROR + "${OSHMEM_INFO} Failed! Program status code: ${OSHMEM_INFO_STATUS}" + ) + endif() + + file(READ ${OSHMEM_INFO_OUTPUT} OSHMEM_INFO_OUTPUT_CONTENT) + + if(NOT DEFINED OSHMEM_INFO_OUTPUT_CONTENT) + message( + FATAL_ERROR + "${OSHMEM_INFO} Failed! Check: ${OSHMEM_INFO_ERROR}\n${OSHMEM_INFO_OUTPUT_CONTENT}" + ) + endif() + + if("${OSHMEM_INFO_OUTPUT_CONTENT}" STREQUAL "") + message( + FATAL_ERROR + "${OSHMEM_INFO} Failed! Check: ${OSHMEM_INFO_ERROR}\n${OSHMEM_INFO_OUTPUT_CONTENT}" + ) + endif() + + string(REGEX MATCH "(\/.*)" OSHMEM_LIBDIR_PATH + ${OSHMEM_INFO_OUTPUT_CONTENT} + ) + + string(STRIP ${OSHMEM_LIBDIR_PATH} OSHMEM_LIBDIR_PATH) + + set(ENV{PKG_CONFIG_PATH} + "$ENV{PKG_CONFIG_PATH}:${OSHMEM_LIBDIR_PATH}/pkgconfig" + ) + + pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC}) + + if(NOT OPENSHMEM_FOUND) + + set(OSHMEM_INFO_INCOUTPUT + "${CMAKE_CURRENT_SOURCE_DIR}/oshmem_info_stdout_inc.log" + ) + set(OSHMEM_INFO_INCERROR + "${CMAKE_CURRENT_SOURCE_DIR}/oshmem_info_error_inc.log" + ) + + execute_process( + COMMAND bash -c "${OSHMEM_INFO} --path incdir" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE OSHMEM_INFO_INCSTATUS + OUTPUT_FILE ${OSHMEM_INFO_INCOUTPUT} + ERROR_FILE ${OSHMEM_INFO_INCERROR} + ) + + if(OSHMEM_INFO_INCSTATUS) + message( + FATAL_ERROR + "${OSHMEM_INFO} Failed! Program status code: ${OSHMEM_INFO_INCSTATUS}" + ) + endif() + file(READ ${OSHMEM_INFO_INCOUTPUT} OSHMEM_INFO_OUTPUT_INCCONTENT) + + if(NOT DEFINED OSHMEM_INFO_OUTPUT_INCCONTENT) + message( + FATAL_ERROR + "${OSHMEM_INFO} Failed! Check: ${OSHMEM_INFO_INCERROR}" + ) + endif() + + if("${OSHMEM_INFO_OUTPUT_INCCONTENT}" STREQUAL "") + message( + FATAL_ERROR + "${OSHMEM_INFO} Failed! Check: ${OSHMEM_INFO_INCERROR}\n${OSHMEM_INFO_OUTPUT_INCCONTENT}" + ) + endif() + + string(REGEX MATCH "(\/.*)" OSHMEM_INCDIR_PATH + ${OSHMEM_INFO_OUTPUT_INCCONTENT} + ) + + string(STRIP ${OSHMEM_INCDIR_PATH} OSHMEM_INCDIR_PATH) + + set(OPENSHMEM_CFLAGS + "-I${OSHMEM_INCDIR_PATH} -pthread -I${OSHMEM_LIBDIR_PATH}" + ) + set(OPENSHMEM_LDFLAGS "-loshmem") + set(OPENSHMEM_LIBRARY_DIRS "${OSHMEM_LIBDIR_PATH}") + + add_library(PkgConfig::OPENSHMEM INTERFACE IMPORTED GLOBAL) + + set(OPENSHMEM_FOUND ON) + endif() + endif() + else() + + include(cmake/FindOpenShmemPmi.cmake) + + set(PMI_AUTOCONF_OPTS "") + if(NOT PMI_LIBRARY OR NOT PMI_FOUND) + set(PMI_AUTOCONF_OPTS "--enable-pmi-simple") + else() + string(REGEX MATCH "(^\/[^\/]+)" PMI_INCLUDE_DIR_ROOT_PATH + ${PMI_INCLUDE_DIR} + ) + string(REGEX MATCH "(^\/[^\/]+)" PMI_LIBRARY_ROOT_PATH ${PMI_LIBRARY}) + set(PMI_AUTOCONF_OPTS + "--with-pmi=${PMI_INCLUDE_DIR_ROOT_PATH} --with-pmi-libdir=${PMI_LIBRARY_ROOT_PATH}" + ) + endif() + + set(OPENSHMEM_PC "osss-ucx") + + pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC}) + if(NOT OPENSHMEM_FOUND) + set(OPENSHMEM_PC "sandia-openshmem") + pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC}) + endif() + endif() + endif() + + if(OPENSHMEM_CFLAGS) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(IDX 0) + set(FLAG_LIST "") + + foreach(X IN ITEMS ${OPENSHMEM_CFLAGS}) + string(FIND "${X}" "--param" PARAM_FOUND) + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + endif() + endforeach() + + list(LENGTH OPENSHMEM_CFLAGS IDX) + foreach(X RANGE ${IDX}) + list(POP_FRONT OPENSHMEM_CFLAGS NEWPARAM) + endforeach() + + foreach(X IN ITEMS ${FLAG_LIST}) + list(APPEND OPENSHMEM_CFLAGS "${X}") + endforeach() + endif() + + if(OPENSHMEM_CFLAGS_OTHER) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(IDX 0) + set(FLAG_LIST "") + + foreach(X IN ITEMS ${OPENSHMEM_CFLAGS_OTHER}) + string(FIND "${X}" "--param" PARAM_FOUND) + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + endif() + endforeach() + + list(LENGTH OPENSHMEM_CFLAGS_OTHER IDX) + foreach(X RANGE ${IDX}) + list(POP_FRONT OPENSHMEM_CFLAGS_OTHER NEWPARAM) + endforeach() + + foreach(X IN ITEMS ${FLAG_LIST}) + list(APPEND OPENSHMEM_CFLAGS_OTHER "${X}") + endforeach() + endif() + + if(OPENSHMEM_LDFLAGS) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(IDX 0) + set(DIRIDX 0) + set(SKIP 0) + set(FLAG_LIST "") + set(DIR_LIST "") + set(LIB_LIST "") + + foreach(X IN ITEMS ${OPENSHMEM_LDFLAGS}) + string(FIND "${X}" "--param" PARAM_FOUND) + string(FIND "${X}" "-lsma" IDX) + string(FIND "${X}" "-l" LIDX) + string(FIND "${X}" "-L" DIRIDX) + string(FIND "${X}" "-Wl" SKIP) + + if("${SKIP}" EQUAL "-1") + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IDX}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + elseif(NOT "${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-l" "" TMPSTR "${X}") + list(APPEND LIB_LIST "${TMPSTR}") + set(IDX 0) + elseif("${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + list(APPEND FLAG_LIST "${X}") + endif() + if(NOT "${DIRIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-L" "" TMPSTR "${X}") + list(APPEND DIR_LIST "${TMPSTR}") + endif() + endif() + endforeach() + + set(IDX 0) + list(LENGTH LIB_LIST IDX) + + if(NOT "${IDX}" EQUAL "0") + set(IDX 0) + + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(NEWLINK "SHELL:-Wl,--whole-archive + " + ) + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(APPEND NEWLINK " + -Wl,--no-whole-archive" + ) + string(FIND "SHELL:-Wl,--whole-archive + -Wl,--no-whole-archive" "${NEWLINK}" IDX + ) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_LDFLAGS "${NEWLINK}") + endif() + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + if(APPLE) + set(NEWLINK "SHELL:-Wl,-force_load,") + else() + set(NEWLINK "SHELL: + " + ) + endif() + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(FIND "SHELL:" "${NEWLINK}" IDX) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_LDFLAGS "${NEWLINK}") + endif() + endif() + endif() + endif() + + if(OPENSHMEM_LDFLAGS_OTHER) + unset(FOUND_LIB) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(SKIP 0) + set(IDX 0) + set(DIRIDX 0) + set(FLAG_LIST "") + set(DIR_LIST "") + set(LIB_LIST "") + + foreach(X IN ITEMS ${OPENSHMEM_LDFLAGS_OTHER}) + string(FIND "${X}" "--param" PARAM_FOUND) + string(FIND "${X}" "-lsma" IDX) + string(FIND "${X}" "-L" DIRIDX) + string(FIND "${X}" "-Wl" SKIP) + + if("${SKIP}" EQUAL "-1") + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IDX}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + elseif(NOT "${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-l" "" TMPSTR "${X}") + list(APPEND LIB_LIST "${TMPSTR}") + set(IDX 0) + elseif("${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + list(APPEND FLAG_LIST "${X}") + endif() + if(NOT "${DIRIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-L" "" TMPSTR "${X}") + list(APPEND DIR_LIST "${TMPSTR}") + endif() + endif() + endforeach() + + set(IDX 0) + list(LENGTH OPENSHMEM_LDFLAGS_OTHER IDX) + foreach(X RANGE ${IDX}) + list(POP_FRONT OPENSHMEM_LDFLAGS_OTHER NEWPARAM) + endforeach() + + foreach(X IN ITEMS ${FLAG_LIST}) + list(APPEND OPENSHMEM_LDFLAGS_OTHER "${X}") + endforeach() + + set(IDX 0) + list(LENGTH LIB_LIST IDX) + if(NOT "${IDX}" EQUAL "0") + set(IDX 0) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(NEWLINK "SHELL:-Wl,--whole-archive + " + ) + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(APPEND NEWLINK " + -Wl,--no-whole-archive" + ) + + string(FIND "SHELL:-Wl,--whole-archive + -Wl,--no-whole-archive" "${NEWLINK}" IDX + ) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_LDFLAGS_OTHER "${NEWLINK}") + endif() + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + if(APPLE) + set(NEWLINK "SHELL:-Wl,-force_load,") + else() + set(NEWLINK "SHELL: + " + ) + endif() + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(FIND "SHELL:" "${NEWLINK}" IDX) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_LDFLAGS "${NEWLINK}") + endif() + endif() + endif() + endif() + + if(OPENSHMEM_STATIC_CFLAGS) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(IDX 0) + set(FLAG_LIST "") + + foreach(X IN ITEMS ${OPENSHMEM_STATIC_CFLAGS}) + string(FIND "${X}" "--param" PARAM_FOUND) + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + endif() + endforeach() + + list(LENGTH OPENSHMEM_STATIC_CFLAGS IDX) + foreach(X RANGE ${IDX}) + list(POP_FRONT OPENSHMEM_STATIC_CFLAGS NEWPARAM) + endforeach() + + foreach(X IN ITEMS ${FLAG_LIST}) + list(APPEND OPENSHMEM_STATIC_CFLAGS "${X}") + endforeach() + endif() + + if(OPENSHMEM_STATIC_CFLAGS_OTHER) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(IDX 0) + set(FLAG_LIST "") + foreach(X IN ITEMS ${OPENSHMEM_STATIC_CFLAGS_OTHER}) + string(FIND "${X}" "--param" PARAM_FOUND) + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + endif() + endforeach() + + list(LENGTH OPENSHMEM_STATIC_CFLAGS_OTHER IDX) + foreach(X RANGE ${IDX}) + list(POP_FRONT OPENSHMEM_STATIC_CFLAGS_OTHER NEWPARAM) + endforeach() + + foreach(X IN ITEMS ${FLAG_LIST}) + list(APPEND OPENSHMEM_STATIC_CFLAGS_OTHER "${X}") + endforeach() + endif() + + if(OPENSHMEM_STATIC_LDFLAGS) + unset(FOUND_LIB) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(SKIP 0) + set(IDX 0) + set(DIRIDX 0) + set(FLAG_LIST "") + set(DIR_LIST "") + set(LIB_LIST "") + foreach(X IN ITEMS ${OPENSHMEM_STATIC_LDFLAGS}) + string(FIND "${X}" "--param" PARAM_FOUND) + if("${HPX_WITH_PARCELPORT_OPENSHMEM_CONDUIT}" STREQUAL "mpi") + string(FIND "${X}" "-loshmem" IDX) + else() + string(FIND "${X}" "-lsma" IDX) + endif() + string(FIND "${X}" "-L" DIRIDX) + string(FIND "${X}" "-Wl" SKIP) + + if("${SKIP}" EQUAL "-1") + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IDX}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + elseif(NOT "${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-l" "" TMPSTR "${X}") + list(APPEND LIB_LIST "${TMPSTR}") + set(IDX 0) + elseif("${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + list(APPEND FLAG_LIST "${X}") + endif() + if(NOT "${DIRIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-L" "" TMPSTR "${X}") + list(APPEND DIR_LIST "${TMPSTR}") + endif() + endif() + endforeach() + set(IDX 0) + list(LENGTH OPENSHMEM_STATIC_LDFLAGS IDX) + foreach(X RANGE ${IDX}) + list(POP_FRONT OPENSHMEM_STATIC_LDFLAGS NEWPARAM) + endforeach() + + foreach(X IN ITEMS ${FLAG_LIST}) + list(APPEND OPENSHMEM_STATIC_LDFLAGS "${X}") + endforeach() + + set(IDX 0) + list(LENGTH LIB_LIST IDX) + if(NOT "${IDX}" EQUAL "0") + set(IDX 0) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(NEWLINK "SHELL:-Wl,--whole-archive + " + ) + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(APPEND NEWLINK " + -Wl,--no-whole-archive" + ) + + string(FIND "SHELL:-Wl,--whole-archive + -Wl,--no-whole-archive" "${NEWLINK}" IDX + ) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_STATIC_LDFLAGS "${NEWLINK}") + endif() + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + if(APPLE) + set(NEWLINK "SHELL:-Wl,-force_load,") + else() + set(NEWLINK "SHELL: + " + ) + endif() + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(FIND "SHELL:" "${NEWLINK}" IDX) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_LDFLAGS "${NEWLINK}") + endif() + endif() + endif() + endif() + + if(OPENSHMEM_STATIC_LDFLAGS_OTHER) + unset(FOUND_LIB) + set(IS_PARAM "0") + set(PARAM_FOUND "0") + set(NEWPARAM "") + set(SKIP 0) + set(IDX 0) + set(DIRIDX 0) + set(FLAG_LIST "") + set(DIR_LIST "") + set(LIB_LIST "") + + foreach(X IN ITEMS ${OPENSHMEM_STATIC_LDFLAGS_OTHER}) + string(FIND "${X}" "--param" PARAM_FOUND) + if("${HPX_WITH_PARCELPORT_OPENSHMEM_CONDUIT}" STREQUAL "mpi") + string(FIND "${X}" "-loshmem" IDX) + else() + string(FIND "${X}" "-lsma" IDX) + endif() + string(FIND "${X}" "-L" DIRIDX) + string(FIND "${X}" "-Wl" SKIP) + + if("${SKIP}" EQUAL "-1") + if(NOT "${PARAM_FOUND}" EQUAL "-1") + set(IS_PARAM "1") + set(NEWPARAM "SHELL:${X}") + endif() + if("${PARAM_FOUND}" EQUAL "-1" + AND "${IDX}" EQUAL "-1" + AND "${IS_PARAM}" EQUAL "0" + OR "${IS_PARAM}" EQUAL "-1" + ) + list(APPEND FLAG_LIST "${X}") + set(IS_PARAM "0") + elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1") + list(APPEND FLAG_LIST "${NEWPARAM} + ${X}" + ) + set(NEWPARAM "") + set(IS_PARAM "0") + elseif(NOT "${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-l" "" TMPSTR "${X}") + list(APPEND LIB_LIST "${TMPSTR}") + set(IDX 0) + elseif("${IDX}" EQUAL "-1" AND NOT "${LIDX}" EQUAL "-1") + list(APPEND FLAG_LIST "${X}") + endif() + if(NOT "${DIRIDX}" EQUAL "-1") + set(TMPSTR "") + string(REPLACE "-L" "" TMPSTR "${X}") + list(APPEND DIR_LIST "${TMPSTR}") + endif() + endif() + endforeach() + + set(IDX 0) + list(LENGTH OPENSHMEM_STATIC_LDFLAGS_OTHER IDX) + foreach(X RANGE ${IDX}) + list(POP_FRONT OPENSHMEM_STATIC_LDFLAGS_OTHER NEWPARAM) + endforeach() + + foreach(X IN ITEMS ${FLAG_LIST}) + list(APPEND OPENSHMEM_STATIC_LDFLAGS_OTHER "${X}") + endforeach() + + set(IDX 0) + list(LENGTH LIB_LIST IDX) + if(NOT "${IDX}" EQUAL "0") + set(IDX 0) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(NEWLINK "SHELL:-Wl,--whole-archive + " + ) + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + + message(STATUS "${FOUND_LIB} + ${X}" + ) + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(APPEND NEWLINK " + -Wl,--no-whole-archive" + ) + string(FIND "SHELL:-Wl,--whole-archive + -Wl,--no-whole-archive" "${NEWLINK}" IDX + ) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_STATIC_LDFLAGS_OTHER "${NEWLINK}") + endif() + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + if(APPLE) + set(NEWLINK "SHELL:-Wl,-force_load,") + else() + set(NEWLINK "SHELL: + " + ) + endif() + foreach(X IN ITEMS ${LIB_LIST}) + set(DIRSTR "") + string(REPLACE ";" " + " DIRSTR "${DIR_LIST}" + ) + foreach(Y IN ITEMS ${DIR_LIST}) + find_library( + FOUND_LIB + NAMES ${X} "lib${X}" "lib${X}.a" + PATHS ${Y} + HINTS ${Y} NO_CACHE + NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH + ) + + list(LENGTH FOUND_LIB IDX) + if(NOT "${IDX}" EQUAL "0") + string(APPEND NEWLINK "${FOUND_LIB}") + set(FOUND_LIB "") + endif() + endforeach() + endforeach() + string(FIND "SHELL:" "${NEWLINK}" IDX) + if("${IDX}" EQUAL "-1") + list(APPEND OPENSHMEM_LDFLAGS "${NEWLINK}") + endif() + endif() + endif() + endif() + + if(OPENSHMEM_DIR) + list(TRANSFORM OPENSHMEM_CFLAGS + REPLACE "${OPENSHMEM_DIR}/install" + "$" + ) + list(TRANSFORM OPENSHMEM_LDFLAGS + REPLACE "${OPENSHMEM_DIR}/install" + "$" + ) + list(TRANSFORM OPENSHMEM_LIBRARY_DIRS + REPLACE "${OPENSHMEM_DIR}/install" + "$" + ) + + message(STATUS "OPENSHMEM_CFLAGS:\t${OPENSHMEM_CFLAGS}") + message(STATUS "OPENSHMEM_LDFLAGS:\t${OPENSHMEM_LDFLAGS}") + message(STATUS "OPENSHMEM_LIBRARY_DIRS:\t${OPENSHMEM_LIBRARY_DIRS}") + + set_target_properties( + PkgConfig::OPENSHMEM PROPERTIES INTERFACE_COMPILE_OPTIONS + "${OPENSHMEM_CFLAGS}" + ) + set_target_properties( + PkgConfig::OPENSHMEM PROPERTIES INTERFACE_LINK_OPTIONS + "${OPENSHMEM_LDFLAGS}" + ) + set_target_properties( + PkgConfig::OPENSHMEM PROPERTIES INTERFACE_LINK_DIRECTORIES + "${OPENSHMEM_LIBRARY_DIRS}" + ) + set(OPENSHMEM_FOUND ON) + else() + message(STATUS "OPENSHMEM_CFLAGS:\t${OPENSHMEM_CFLAGS}") + message(STATUS "OPENSHMEM_LDFLAGS:\t${OPENSHMEM_LDFLAGS}") + message(STATUS "OPENSHMEM_LIBRARY_DIRS:\t${OPENSHMEM_LIBRARY_DIRS}") + + set_target_properties( + PkgConfig::OPENSHMEM PROPERTIES INTERFACE_COMPILE_OPTIONS + "${OPENSHMEM_CFLAGS}" + ) + set_target_properties( + PkgConfig::OPENSHMEM PROPERTIES INTERFACE_LINK_OPTIONS + "${OPENSHMEM_LDFLAGS}" + ) + set_target_properties( + PkgConfig::OPENSHMEM PROPERTIES INTERFACE_LINK_DIRECTORIES + "${OPENSHMEM_LIBRARY_DIRS}" + ) + set(OPENSHMEM_FOUND ON) + endif() +endmacro() diff --git a/cmake/FindOpenShmemPmi.cmake b/cmake/FindOpenShmemPmi.cmake new file mode 100644 index 000000000..5f6814a50 --- /dev/null +++ b/cmake/FindOpenShmemPmi.cmake @@ -0,0 +1,65 @@ +# Copyright (c) 2023 Christopher Taylor +# +# SPDX-License-Identifier: BSL-1.0 +# Distributed under the Boost Software License, Version 1.0. (See accompanying +# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +find_package(PkgConfig QUIET) +# look for cray pmi... +pkg_check_modules(PC_PMI_CRAY QUIET cray-pmi) +# look for the rest if we couldn't find the cray package +if(NOT PC_PMI_CRAY_FOUND) + pkg_check_modules(PC_PMI QUIET pmi) +endif() + +find_path( + PMI_INCLUDE_DIR pmi2.h + HINTS ${PMI_ROOT} + ENV + PMI_ROOT + ${PMI_DIR} + ENV + PMI_DIR + ${PC_PMI_CRAY_INCLUDEDIR} + ${PC_PMI_CRAY_INCLUDE_DIRS} + ${PC_PMI_INCLUDEDIR} + ${PC_PMI_INCLUDE_DIRS} + PATH_SUFFIXES include +) + +find_library( + PMI_LIBRARY + NAMES pmi + HINTS ${PMI_ROOT} + ENV + PMI_ROOT + ${PC_PMI_CRAY_LIBDIR} + ${PC_PMI_CRAY_LIBRARY_DIRS} + ${PC_PMI_LIBDIR} + ${PC_PMI_LIBRARY_DIRS} + PATH_SUFFIXES lib lib64 +) + +# Set PMI_ROOT in case the other hints are used +if(PMI_ROOT) + # The call to file is for compatibility with windows paths + file(TO_CMAKE_PATH ${PMI_ROOT} PMI_ROOT) +elseif("$ENV{PMI_ROOT}") + file(TO_CMAKE_PATH $ENV{PMI_ROOT} PMI_ROOT) +else() + file(TO_CMAKE_PATH "${PMI_INCLUDE_DIR}" PMI_INCLUDE_DIR) + string(REPLACE "/include" "" PMI_ROOT "${PMI_INCLUDE_DIR}") +endif() + +if(NOT PMI_LIBRARY OR NOT PMI_INCLUDE_DIR) + set(PMI_FOUND=OFF) + return() +endif() + +# hpx_error( "PMI_LIBRARY OR PMI_INCLUDE_DIR not found, please install PMI or +# set \ the right PMI_ROOT path" ) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(PMI DEFAULT_MSG PMI_LIBRARY PMI_INCLUDE_DIR) + +mark_as_advanced(PMI_ROOT PMI_LIBRARY PMI_INCLUDE_DIR) diff --git a/ggml-oshmem.c b/ggml-oshmem.c new file mode 100644 index 000000000..1a1b85dea --- /dev/null +++ b/ggml-oshmem.c @@ -0,0 +1,398 @@ +#include "ggml-oshmem.h" + +#include "ggml.h" + +#include + +#include +#include +#include + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +#define UNUSED GGML_UNUSED + +#define OPENSHMEM_SYMMETRIC_BUFFER_SIZE 4096 + +struct ggml_openshmem_context { + int pe; + int n_pes; + int64_t symmetric_buffer_size; + int64_t symmetric_comm_structure_size; + uint8_t * symmetric_comm_structure; +}; + +void ggml_openshmem_backend_init(void) { + int provided = 0; + shmem_init_thread(SHMEM_THREAD_MULTIPLE, &provided); +} + +void ggml_openshmem_backend_free(void) { + shmem_finalize(); +} + +struct ggml_openshmem_context * ggml_openshmem_init(void) { + struct ggml_openshmem_context * ctx = + (struct ggml_openshmem_context *)calloc(1, sizeof(struct ggml_openshmem_context)); + + ctx->pe = shmem_my_pe(); + ctx->n_pes = shmem_n_pes(); + + /* + * makes a symmetric heap allocation on all processing elements (processes running this SPMD program) + * + * below is a struct representing the layout of the symmetric allocation: + * + * { + * int64_t offset_in_buffer, + * int64_t length_in_buffer, + * uint8_t buffer[shmem_npes()][OPENSHMEM_SYMMETRIC_BUFFER_SIZE] + * } + * + */ + ctx->symmetric_buffer_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE; + ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t) + sizeof(uint64_t) + sizeof(uint64_t); + ctx->symmetric_comm_structure = (uint8_t*)shmem_calloc(1, ctx->n_pes*ctx->symmetric_comm_structure_size); + + return ctx; +} + +void ggml_openshmem_free(struct ggml_openshmem_context * ctx) { + shmem_free(ctx->symmetric_comm_structure); + free(ctx); +} + +int ggml_openshmem_pe(struct ggml_openshmem_context * ctx) { + return ctx->pe; +} + +void ggml_openshmem_eval_init( + struct ggml_openshmem_context * ctx, + int * n_tokens, + int * n_past, + int * n_threads) { + UNUSED(ctx); + + uint8_t * dst_symmetric_comm_structure = + ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe)+sizeof(uint64_t)+sizeof(uint64_t); + int64_t * dst_symmetric_comm_offset = + (int64_t*)(dst_symmetric_comm_structure); + + // synchronize the worker node parameters with the root node + shmem_barrier_all(); + + memcpy((int*)dst_symmetric_comm_offset, n_tokens, sizeof(int)); + memcpy(((int*)dst_symmetric_comm_offset)+sizeof(int), n_past, sizeof(int)); + memcpy(((int*)dst_symmetric_comm_offset)+sizeof(int)+sizeof(int), n_threads, sizeof(int)); + + shmem_int32_broadcast(SHMEM_TEAM_WORLD, (int*)dst_symmetric_comm_offset, (int*)dst_symmetric_comm_offset, 3, 0); + + memcpy(n_tokens, ((int*)dst_symmetric_comm_offset), sizeof(int)); + memcpy(n_past, ((int*)dst_symmetric_comm_offset)+sizeof(int), sizeof(int)); + memcpy(n_threads, ((int*)dst_symmetric_comm_offset)+sizeof(int)+sizeof(int), sizeof(int)); + + shmem_quiet(); +} + +static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { + struct ggml_tensor * t = ggml_graph_get_tensor(gf, name); + if (t == NULL) { + fprintf(stderr, "%s: tensor %s not found\n", __func__, name); + return -1; + } + + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->nodes[i] == t) { + return i; + } + } + + fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name); + return -1; +} + +/* + * The OpenSHMEM mechanism used in this application reflects a message passing model; this is a byproduct of OpenSHMEM's symmetric memory requirements. + * Care has been taken to limit the number of branches made in send/recv and the amount of two-sided communication. Memory consistency maybe an issue + * which is why a `shmem_fence` is placed at the end of both send/recv. + * + */ +static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int dst_pe) { + + const int64_t symmetric_comm_structure_size = + ctx->symmetric_comm_structure_size; + + uint64_t * my_recv_signal = + ((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*ctx->pe); + uint64_t * dst_recv_signal = + ((uint64_t*)my_recv_signal)+sizeof(uint64_t); + + uint8_t * dst_symmetric_comm_structure = + ((uint8_t*)dst_recv_signal)+sizeof(uint64_t); + int64_t * dst_symmetric_comm_offset = + (int64_t*)(dst_symmetric_comm_structure); + int64_t * dst_symmetric_comm_length = + ((int64_t*)dst_symmetric_comm_offset)+sizeof(int64_t); + uint8_t * dst_symmetric_comm_buffer = + ((uint8_t*)dst_symmetric_comm_length)+sizeof(int64_t); + + const int64_t nelements = ggml_nelements(t); + int64_t xmt_size = 0; + + switch (t->type) { + case GGML_TYPE_I32: + xmt_size = nelements * sizeof(int32_t); + break; + case GGML_TYPE_F32: + xmt_size = nelements * sizeof(int32_t); + break; + default: GGML_ASSERT(false && "not implemented"); + } + + int64_t init_segments = (xmt_size / OPENSHMEM_SYMMETRIC_BUFFER_SIZE); + int64_t xmt_amount [2] = { + OPENSHMEM_SYMMETRIC_BUFFER_SIZE, + xmt_size - (OPENSHMEM_SYMMETRIC_BUFFER_SIZE * init_segments) + }; + int64_t xmt_byte_offset = 0; + int64_t xmt_byte_amount = 0; + + const int64_t total_loop_count = + init_segments + !( xmt_amount[1] < 1); + + memcpy( + dst_symmetric_comm_offset, + &total_loop_count, + sizeof(int64_t) + ); + + shmem_int64_put_signal( + dst_symmetric_comm_offset, + dst_symmetric_comm_offset, + sizeof(int64_t), + dst_recv_signal, + 1, + SHMEM_SIGNAL_SET, + dst_pe + ); + + shmem_wait_until( + my_recv_signal, + SHMEM_CMP_EQ, + 1 + ); + + (*my_recv_signal) = 0; + + xmt_byte_amount = xmt_amount[0 == (total_loop_count-1)]; + + for(int32_t i = 0; i < total_loop_count; ++i) { + memcpy(dst_symmetric_comm_offset, &xmt_byte_offset, sizeof(int64_t)); + memcpy(dst_symmetric_comm_length, &xmt_byte_amount, sizeof(int64_t)); + memcpy(dst_symmetric_comm_buffer, ((uint8_t*)t->data)+xmt_byte_offset, xmt_byte_amount); + + shmem_uint8_put_signal( + dst_symmetric_comm_structure, + dst_symmetric_comm_structure, + symmetric_comm_structure_size, + dst_recv_signal, + 1, + SHMEM_SIGNAL_SET, + dst_pe + ); + + shmem_wait_until( + my_recv_signal, + SHMEM_CMP_EQ, + 1 + ); + + (*my_recv_signal) = 0; + + xmt_byte_offset += xmt_byte_amount; + xmt_amount[1] -= xmt_byte_amount; + xmt_byte_amount = xmt_amount[i == (total_loop_count-1)]; + } + + shmem_fence(); +} + +static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int src_pe) { + + const int64_t symmetric_comm_structure_size = + ctx->symmetric_comm_structure_size; + + uint64_t * src_recv_signal = + ((uint64_t*)ctx->symmetric_comm_structure)+(symmetric_comm_structure_size*src_pe); + uint64_t * my_recv_signal = + ((uint64_t*)src_recv_signal)+sizeof(uint64_t); + + uint8_t * src_symmetric_comm_structure = + ((uint8_t*)my_recv_signal)+sizeof(uint64_t); + int64_t * src_symmetric_comm_offset = + (int64_t*)(src_symmetric_comm_structure); + int64_t * src_symmetric_comm_length = + ((int64_t*)src_symmetric_comm_offset)+sizeof(int64_t); + uint8_t * src_symmetric_comm_buffer = + ((uint8_t*)src_symmetric_comm_length)+sizeof(int64_t); + + int64_t total_loop_count = 0; + + shmem_wait_until( + my_recv_signal, + SHMEM_CMP_EQ, + 1 + ); + (*my_recv_signal) = 0; + + memcpy(&total_loop_count, src_symmetric_comm_offset, sizeof(int64_t)); + shmem_uint8_put_signal( + src_symmetric_comm_structure, + src_symmetric_comm_structure, + 0, + src_recv_signal, + 1, + SHMEM_SIGNAL_SET, + src_pe + ); + + for(int32_t i = 0; i < total_loop_count; ++i) { + shmem_wait_until( + my_recv_signal, + SHMEM_CMP_EQ, + 1 + ); + (*my_recv_signal) = 0; + + memcpy( + ((uint8_t*)t->data)+(*src_symmetric_comm_offset), + src_symmetric_comm_buffer+(*src_symmetric_comm_offset), + (*src_symmetric_comm_length) + ); + + shmem_uint8_put_signal( + src_symmetric_comm_structure, + src_symmetric_comm_structure, + 0, + src_recv_signal, + 1, + SHMEM_SIGNAL_SET, + src_pe + ); + } + + shmem_fence(); +} + +// TODO: there are many improvements that can be done to this implementation +void ggml_openshmem_graph_compute_pre( + struct ggml_openshmem_context * ctx_openshmem, + struct ggml_cgraph * gf, + int n_layers) { + const int openshmem_pe = ctx_openshmem->pe; + const int openshmem_size = ctx_openshmem->n_pes; + + struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens"); + if (inp_tokens == NULL) { + fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__); + return; + } + + struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0"); + if (inp0 == NULL) { + fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__); + return; + } + + GGML_ASSERT(inp0 == gf->nodes[0]); + + // distribute the compute graph into slices across the MPI nodes + // + // the main node (0) processes the last layers + the remainder of the compute graph + // and is responsible to pass the input tokens to the first node (1) + // + // node 1: [( 0) * n_per_node, ( 1) * n_per_node) + // node 2: [( 1) * n_per_node, ( 2) * n_per_node) + // ... + // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node) + // node 0: [(n-1) * n_per_node, n_nodes) + // + { + struct ggml_tensor * input_tokens[2] = { inp_tokens, inp0 }; + + if (openshmem_pe > 0) { + ggml_openshmem_tensor_recv(ctx_openshmem, input_tokens[openshmem_pe == 1], openshmem_pe-1); + } + else if (openshmem_size > 1) { + // node 0 sends the input tokens to node 1 + ggml_openshmem_tensor_send(ctx_openshmem, input_tokens[0], 1); + + // recv the output data from the last node + ggml_openshmem_tensor_recv(ctx_openshmem, input_tokens[1], openshmem_size - 1); + } + } + + { + const int n_per_node = (n_layers + (openshmem_size - 1)) / openshmem_size; + + const int openshmem_idx = openshmem_pe > 0 ? openshmem_pe - 1 : openshmem_size - 1; + + const int il0 = (openshmem_idx + 0) * n_per_node; + const int il1 = MIN(n_layers, (openshmem_idx + 1) * n_per_node); + + char name_l0[GGML_MAX_NAME]; + char name_l1[GGML_MAX_NAME]; + + snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); + snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); + + const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); + const int idx_l1 = openshmem_pe > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes; + + if (idx_l0 < 0 || idx_l1 < 0) { + fprintf(stderr, "%s: layer input nodes not found\n", __func__); + return; + } + + // attach the input data to all nodes that need it + // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below) + for (int i = idx_l0; i < idx_l1; i++) { + if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) { + gf->nodes[i]->src[0] = inp0; + } + if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) { + gf->nodes[i]->src[1] = inp0; + } + } + + // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph + for (int i = 1; i < idx_l1 - idx_l0; i++) { + gf->nodes[i] = gf->nodes[idx_l0 + i]; + gf->grads[i] = gf->grads[idx_l0 + i]; + } + + // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node + if (openshmem_idx != 0) { + gf->nodes[0]->op = GGML_OP_NONE; + } + + gf->n_nodes = idx_l1 - idx_l0; + + //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, openshmem_pe, gf->n_nodes, il0, il1); + } +} + +void ggml_openshmem_graph_compute_post( + struct ggml_openshmem_context * ctx_openshmem, + struct ggml_cgraph * gf, + int n_layers) { + UNUSED(n_layers); + + const int openshmem_pe = ctx_openshmem->pe; + const int openshmem_size = ctx_openshmem->n_pes; + + // send the output data to the next node + if (openshmem_pe > 0) { + ggml_openshmem_tensor_send(ctx_openshmem, gf->nodes[gf->n_nodes - 1], (openshmem_pe + 1) % openshmem_size); + } +} diff --git a/ggml-oshmem.h b/ggml-oshmem.h new file mode 100644 index 000000000..ea88585ad --- /dev/null +++ b/ggml-oshmem.h @@ -0,0 +1,43 @@ +#pragma once +#ifndef __LLAMA_CPP_GGML_OSHMEM_H__ +#define __LLAMA_CPP_GGML_OSHMEM_H__ + +struct ggml_context; +struct ggml_tensor; +struct ggml_cgraph; + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_openshmem_context; + +void ggml_openshmem_backend_init(void); +void ggml_openshmem_backend_free(void); + +struct ggml_openshmem_context * ggml_openshmem_init(void); +void ggml_openshmem_free(struct ggml_openshmem_context * ctx); + +int ggml_openshmem_pe(struct ggml_openshmem_context * ctx); + +void ggml_openshmem_eval_init( + struct ggml_openshmem_context * ctx_openshmem, + int * n_tokens, + int * n_past, + int * n_threads); + +void ggml_openshmem_graph_compute_pre( + struct ggml_openshmem_context * ctx_openshmem, + struct ggml_cgraph * gf, + int n_layers); + +void ggml_openshmem_graph_compute_post( + struct ggml_openshmem_context * ctx_openshmem, + struct ggml_cgraph * gf, + int n_layers); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/llama.cpp b/llama.cpp index 2c5983c67..d95c7b046 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19,6 +19,9 @@ #ifdef GGML_USE_MPI # include "ggml-mpi.h" #endif +#ifdef GGML_USE_OPENSHMEM +# include "ggml-oshmem.h" +#endif #ifndef QK_K # ifdef GGML_QKK_64 # define QK_K 64 @@ -1675,6 +1678,11 @@ struct llama_context { #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; #endif + +#ifdef GGML_USE_OPENSHMEM + ggml_openshmem_context * ctx_oshmem = NULL; +#endif + }; // @@ -6289,6 +6297,12 @@ static int llama_decode_internal( ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); #endif +#if GGML_USE_OPENSHMEM + const int64_t n_layer = hparams.n_layer; + ggml_openshmem_graph_compute_pre(lctx.ctx_oshmem, gf, n_layer); +#endif + + #ifdef GGML_USE_METAL if (ggml_backend_is_metal(lctx.backend_metal)) { ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); @@ -6306,6 +6320,10 @@ static int llama_decode_internal( ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); #endif +#if GGML_USE_OPENSHEM + ggml_openshmem_graph_compute_post(lctx.ctx_oshmem, gf, n_layer); +#endif + // update the kv ring buffer { if (kv_self.has_shift) { @@ -9330,12 +9348,21 @@ void llama_backend_init(bool numa) { #ifdef GGML_USE_MPI ggml_mpi_backend_init(); #endif + +#ifdef GGML_USE_OPENSHMEM + ggml_openshmem_backend_init(); +#endif + } void llama_backend_free(void) { #ifdef GGML_USE_MPI ggml_mpi_backend_free(); #endif +#ifdef GGML_USE_OPENSHMEM + ggml_openshmem_backend_free(); +#endif + } int64_t llama_time_us(void) { @@ -9577,6 +9604,20 @@ struct llama_context * llama_new_context_with_model( } #endif +#ifdef GGML_USE_OPENSHMEM + ctx->ctx_oshmem = ggml_openshmem_init(); + + if (ggml_openshmem_pe(ctx->ctx_oshmem) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); + //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); + exit(1); + } +#endif + return ctx; }