diff --git a/CMakeLists.txt b/CMakeLists.txt
index e3cd43ab3..e58041af7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -95,6 +95,7 @@ option(LLAMA_CLBLAST "llama: use CLBlast"
 option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
 option(LLAMA_MPI "llama: use MPI" OFF)
+option(LLAMA_OPENSHMEM "llama: use OpenSHMEM" OFF)
 option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
 
 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@@ -344,6 +345,32 @@ if (LLAMA_MPI)
     endif()
 endif()
 
+if (LLAMA_OPENSHMEM)
+    # setup_openshmem() locates OpenSHMEM and defines OPENSHMEM_FOUND,
+    # OPENSHMEM_INCLUDE_DIRS and the PkgConfig::OPENSHMEM target
+    include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/FindOpenSHMEM.cmake)
+
+    setup_openshmem()
+
+    if (OPENSHMEM_FOUND)
+        message(STATUS "OpenSHMEM found")
+        set(GGML_HEADERS_OPENSHMEM ggml-oshmem.h)
+        set(GGML_SOURCES_OPENSHMEM ggml-oshmem.c ggml-oshmem.h)
+        add_compile_definitions(GGML_USE_OPENSHMEM)
+
+        if (NOT MSVC)
+            add_compile_options(-Wno-cast-qual)
+        endif()
+
+        # link through the imported target: OPENSHMEM_LDFLAGS may hold "SHELL:"
+        # entries, which are link options and not valid library names
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} PkgConfig::OPENSHMEM)
+        set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${OPENSHMEM_INCLUDE_DIRS})
+    else()
+        message(WARNING "OpenSHMEM not found")
+    endif()
+endif()
+
 if (LLAMA_CLBLAST)
     find_package(CLBlast)
     if (CLBlast_FOUND)
@@ -722,6 +749,7 @@ add_library(ggml OBJECT
             ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
             ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
             ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
+            ${GGML_SOURCES_OPENSHMEM} ${GGML_HEADERS_OPENSHMEM}
             ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
             )
 
diff --git a/cmake/FindOpenSHMEM.cmake b/cmake/FindOpenSHMEM.cmake
new file mode 100644
index 000000000..550d7639f
--- /dev/null
+++ b/cmake/FindOpenSHMEM.cmake
@@ -0,0 +1,300 @@
+# Copyright (c) 2019-2023 Ste||ar Group
+#
+# SPDX-License-Identifier: BSL-1.0
+# Distributed under the Boost Software License, Version 1.0. (See accompanying
+# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+#
+
+# Directory that holds this module. It is recorded while the file is being
+# included because setup_openshmem() is a macro: inside a macro,
+# CMAKE_CURRENT_LIST_DIR names the file of the caller, not this one.
+set(OPENSHMEM_MODULE_DIR "${CMAKE_CURRENT_LIST_DIR}")
+
+# openshmem_query_path(<tool> <kind> <out_var>)
+#
+# Runs "<tool> --path <kind>" (<tool> is oshmem_info or ompi_info) and stores
+# the absolute path found in its output in <out_var>. Aborts the configure
+# step when the tool fails or prints no path.
+function(openshmem_query_path tool kind out_var)
+  execute_process(
+    COMMAND "${tool}" --path "${kind}"
+    RESULT_VARIABLE status
+    OUTPUT_VARIABLE stdout
+    ERROR_VARIABLE stderr
+  )
+  if(NOT "${status}" STREQUAL "0")
+    message(
+      FATAL_ERROR
+        "${tool} Failed! Program status code: ${status}\n${stderr}"
+    )
+  endif()
+
+  # keep everything from the first "/" on, without surrounding whitespace
+  string(REGEX MATCH "/.*" path "${stdout}")
+  string(STRIP "${path}" path)
+  if("${path}" STREQUAL "")
+    message(FATAL_ERROR "${tool} Failed! Check: ${stderr}\n${stdout}")
+  endif()
+
+  set(${out_var} "${path}" PARENT_SCOPE)
+endfunction()
+
+# openshmem_fold_param_flags(<list_var>)
+#
+# Rewrites the list named <list_var> in the caller's scope: an entry that
+# contains "--param" is merged with the entry after it into one
+# "SHELL:--param <value>" item. All other entries keep their place.
+function(openshmem_fold_param_flags list_var)
+  if("${${list_var}}" STREQUAL "")
+    return()
+  endif()
+
+  set(folded "")
+  set(pending "")
+  foreach(flag IN LISTS ${list_var})
+    if("${flag}" MATCHES "--param")
+      set(pending "SHELL:${flag}")
+    elseif(NOT "${pending}" STREQUAL "")
+      list(APPEND folded "${pending} ${flag}")
+      set(pending "")
+    else()
+      list(APPEND folded "${flag}")
+    endif()
+  endforeach()
+
+  set(${list_var} "${folded}" PARENT_SCOPE)
+endfunction()
+
+# openshmem_rewrite_link_flags(<list_var> <archive_lib> [KEEP_INPUT])
+#
+# Post-processes the linker flag list named <list_var> in the caller's scope:
+#
+# * entries containing "-Wl" are dropped;
+# * "--param" pairs are merged as in openshmem_fold_param_flags();
+# * entries containing "-l<archive_lib>" are removed, the library is searched
+#   in the "-L" directories of the same list, and each hit is appended as a
+#   whole-archive (GNU) or force-load (Clang on Apple) link item.
+#
+# With KEEP_INPUT the original entries are all kept and the whole-archive item
+# is only appended to them.
+function(openshmem_rewrite_link_flags list_var archive_lib)
+  if("${${list_var}}" STREQUAL "")
+    return()
+  endif()
+
+  set(kept "")
+  set(search_dirs "")
+  set(archive_names "")
+  set(pending "")
+  foreach(flag IN LISTS ${list_var})
+    if("${flag}" MATCHES "-Wl")
+      continue()
+    endif()
+
+    if("${flag}" MATCHES "--param")
+      set(pending "SHELL:${flag}")
+    elseif(NOT "${pending}" STREQUAL "")
+      list(APPEND kept "${pending} ${flag}")
+      set(pending "")
+    elseif("${flag}" MATCHES "-l${archive_lib}")
+      string(REPLACE "-l" "" name "${flag}")
+      list(APPEND archive_names "${name}")
+    else()
+      list(APPEND kept "${flag}")
+    endif()
+
+    # "-L" entries stay in the list and double as search directories
+    if("${flag}" MATCHES "-L")
+      string(REPLACE "-L" "" dir "${flag}")
+      list(APPEND search_dirs "${dir}")
+    endif()
+  endforeach()
+
+  if("KEEP_INPUT" IN_LIST ARGN)
+    set(result "${${list_var}}")
+  else()
+    set(result "${kept}")
+  endif()
+
+  set(archives "")
+  foreach(name IN LISTS archive_names)
+    foreach(dir IN LISTS search_dirs)
+      find_library(
+        OPENSHMEM_ARCHIVE_PROBE
+        NAMES ${name} "lib${name}" "lib${name}.a"
+        PATHS "${dir}"
+        NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH
+      )
+      set(hit "${OPENSHMEM_ARCHIVE_PROBE}")
+      # forget the probe result so that the next directory is really searched
+      unset(OPENSHMEM_ARCHIVE_PROBE CACHE)
+      unset(OPENSHMEM_ARCHIVE_PROBE)
+      if(hit)
+        list(APPEND archives "${hit}")
+        break()
+      endif()
+    endforeach()
+  endforeach()
+
+  if(archives)
+    list(JOIN archives " " joined)
+    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+      list(APPEND result
+           "SHELL:-Wl,--whole-archive ${joined} -Wl,--no-whole-archive"
+      )
+    elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+      if(APPLE)
+        foreach(archive IN LISTS archives)
+          list(APPEND result "-Wl,-force_load,${archive}")
+        endforeach()
+      else()
+        list(APPEND result "SHELL:${joined}")
+      endif()
+    endif()
+  endif()
+
+  set(${list_var} "${result}" PARENT_SCOPE)
+endfunction()
+
+# setup_openshmem()
+#
+# Locates an OpenSHMEM installation and publishes it to the caller's scope:
+#
+# * OPENSHMEM_FOUND        - ON when the PkgConfig::OPENSHMEM target exists
+# * OPENSHMEM_CFLAGS       - compile options
+# * OPENSHMEM_LDFLAGS      - link options
+# * OPENSHMEM_INCLUDE_DIRS - header search directories
+# * OPENSHMEM_LIBRARY_DIRS - library search directories
+# * PkgConfig::OPENSHMEM   - imported target carrying the options above
+#
+# When LLAMA_MPI is enabled and an MPI C installation is found, Open MPI's
+# "oshmem" package is used (pkg-config first, then oshmem_info/ompi_info).
+# Otherwise pkg-config is asked for "osss-ucx" and then "sandia-openshmem".
+macro(setup_openshmem)
+  if(NOT TARGET PkgConfig::OPENSHMEM)
+    # pkg_search_module() only exists once FindPkgConfig has been loaded
+    find_package(PkgConfig QUIET)
+
+    set(OPENSHMEM_PC "")
+
+    find_package(MPI)
+    if(LLAMA_MPI AND MPI_C_FOUND)
+      if(MPI_LIBDIR)
+        set(ENV{PKG_CONFIG_PATH}
+            "$ENV{PKG_CONFIG_PATH}:${MPI_LIBDIR}/pkgconfig"
+        )
+      endif()
+
+      set(OPENSHMEM_PC "oshmem")
+      pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})
+
+      if(NOT OPENSHMEM_FOUND)
+        find_program(OSHMEM_INFO NAMES oshmem_info ompi_info)
+
+        if(NOT OSHMEM_INFO)
+          message(
+            FATAL_ERROR
+              "oshmem_info and/or ompi_info not found! pkg-config cannot find OpenMPI's `${OPENSHMEM_PC}.pc`"
+          )
+        endif()
+
+        openshmem_query_path("${OSHMEM_INFO}" libdir OSHMEM_LIBDIR_PATH)
+
+        set(ENV{PKG_CONFIG_PATH}
+            "$ENV{PKG_CONFIG_PATH}:${OSHMEM_LIBDIR_PATH}/pkgconfig"
+        )
+        pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})
+
+        if(NOT OPENSHMEM_FOUND)
+          # no usable .pc file: describe the installation by hand
+          openshmem_query_path("${OSHMEM_INFO}" incdir OSHMEM_INCDIR_PATH)
+
+          # one list entry per option, so each reaches the compiler separately
+          set(OPENSHMEM_CFLAGS
+              "-I${OSHMEM_INCDIR_PATH}" "-pthread" "-I${OSHMEM_LIBDIR_PATH}"
+          )
+          set(OPENSHMEM_LDFLAGS "-loshmem")
+          set(OPENSHMEM_INCLUDE_DIRS "${OSHMEM_INCDIR_PATH}")
+          set(OPENSHMEM_LIBRARY_DIRS "${OSHMEM_LIBDIR_PATH}")
+
+          add_library(PkgConfig::OPENSHMEM INTERFACE IMPORTED GLOBAL)
+          set_property(
+            TARGET PkgConfig::OPENSHMEM PROPERTY INTERFACE_LINK_LIBRARIES oshmem
+          )
+
+          set(OPENSHMEM_FOUND ON)
+        endif()
+      endif()
+    else()
+      include("${OPENSHMEM_MODULE_DIR}/FindOpenShmemPmi.cmake")
+
+      set(PMI_AUTOCONF_OPTS "")
+      if(NOT PMI_LIBRARY OR NOT PMI_FOUND)
+        set(PMI_AUTOCONF_OPTS "--enable-pmi-simple")
+      else()
+        string(REGEX MATCH "(^/[^/]+)" PMI_INCLUDE_DIR_ROOT_PATH
+                     "${PMI_INCLUDE_DIR}"
+        )
+        string(REGEX MATCH "(^/[^/]+)" PMI_LIBRARY_ROOT_PATH "${PMI_LIBRARY}")
+        set(PMI_AUTOCONF_OPTS
+            "--with-pmi=${PMI_INCLUDE_DIR_ROOT_PATH} --with-pmi-libdir=${PMI_LIBRARY_ROOT_PATH}"
+        )
+      endif()
+
+      set(OPENSHMEM_PC "osss-ucx")
+
+      pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})
+      if(NOT OPENSHMEM_FOUND)
+        set(OPENSHMEM_PC "sandia-openshmem")
+        pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})
+      endif()
+    endif()
+  endif()
+
+  # merge "--param <value>" pairs in every compile flag list
+  foreach(openshmem_flag_list IN ITEMS OPENSHMEM_CFLAGS OPENSHMEM_CFLAGS_OTHER
+          OPENSHMEM_STATIC_CFLAGS OPENSHMEM_STATIC_CFLAGS_OTHER
+  )
+    openshmem_fold_param_flags(${openshmem_flag_list})
+  endforeach()
+
+  # OPENSHMEM_LDFLAGS keeps its original entries and only gains the
+  # whole-archive item; the other link flag lists are rebuilt
+  openshmem_rewrite_link_flags(OPENSHMEM_LDFLAGS sma KEEP_INPUT)
+  foreach(openshmem_flag_list IN ITEMS OPENSHMEM_LDFLAGS_OTHER
+          OPENSHMEM_STATIC_LDFLAGS OPENSHMEM_STATIC_LDFLAGS_OTHER
+  )
+    openshmem_rewrite_link_flags(${openshmem_flag_list} sma)
+  endforeach()
+
+  if(TARGET PkgConfig::OPENSHMEM)
+    if(OPENSHMEM_DIR)
+      # NOTE(review): the replacement text reached this file as a bare "$",
+      # which looks like a generator expression that lost its <...> part;
+      # BUILD_INTERFACE is the presumed original - confirm against upstream.
+      foreach(openshmem_flag_list IN ITEMS OPENSHMEM_CFLAGS OPENSHMEM_LDFLAGS
+              OPENSHMEM_LIBRARY_DIRS
+      )
+        list(TRANSFORM ${openshmem_flag_list}
+             REPLACE "${OPENSHMEM_DIR}/install"
+                     "$<BUILD_INTERFACE:${OPENSHMEM_DIR}/install>"
+        )
+      endforeach()
+    endif()
+
+    message(STATUS "OPENSHMEM_CFLAGS:\t${OPENSHMEM_CFLAGS}")
+    message(STATUS "OPENSHMEM_LDFLAGS:\t${OPENSHMEM_LDFLAGS}")
+    message(STATUS "OPENSHMEM_LIBRARY_DIRS:\t${OPENSHMEM_LIBRARY_DIRS}")
+
+    set_target_properties(
+      PkgConfig::OPENSHMEM
+      PROPERTIES INTERFACE_COMPILE_OPTIONS "${OPENSHMEM_CFLAGS}"
+                 INTERFACE_LINK_OPTIONS "${OPENSHMEM_LDFLAGS}"
+                 INTERFACE_LINK_DIRECTORIES "${OPENSHMEM_LIBRARY_DIRS}"
+    )
+    set(OPENSHMEM_FOUND ON)
+  else()
+    # nothing was located: let the caller take its "not found" branch
+    set(OPENSHMEM_FOUND OFF)
+  endif()
+endmacro()
diff --git a/cmake/FindOpenShmemPmi.cmake b/cmake/FindOpenShmemPmi.cmake
new file mode 100644
index 000000000..5f6814a50
--- /dev/null
+++ b/cmake/FindOpenShmemPmi.cmake
@@ -0,0 +1,65 @@
+# Copyright (c) 2023 Christopher Taylor
+#
+# SPDX-License-Identifier: BSL-1.0
+# Distributed under the Boost Software License, Version 1.0. (See accompanying
+# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+
+find_package(PkgConfig QUIET)
+# look for cray pmi...
+pkg_check_modules(PC_PMI_CRAY QUIET cray-pmi)
+# look for the rest if we couldn't find the cray package
+if(NOT PC_PMI_CRAY_FOUND)
+  pkg_check_modules(PC_PMI QUIET pmi)
+endif()
+
+find_path(
+  PMI_INCLUDE_DIR pmi2.h
+  HINTS ${PMI_ROOT}
+        ENV
+        PMI_ROOT
+        ${PMI_DIR}
+        ENV
+        PMI_DIR
+        ${PC_PMI_CRAY_INCLUDEDIR}
+        ${PC_PMI_CRAY_INCLUDE_DIRS}
+        ${PC_PMI_INCLUDEDIR}
+        ${PC_PMI_INCLUDE_DIRS}
+  PATH_SUFFIXES include
+)
+
+find_library(
+  PMI_LIBRARY
+  NAMES pmi
+  HINTS ${PMI_ROOT}
+        ENV
+        PMI_ROOT
+        ${PC_PMI_CRAY_LIBDIR}
+        ${PC_PMI_CRAY_LIBRARY_DIRS}
+        ${PC_PMI_LIBDIR}
+        ${PC_PMI_LIBRARY_DIRS}
+  PATH_SUFFIXES lib lib64
+)
+
+# Set PMI_ROOT in case the other hints are used
+if(PMI_ROOT)
+  # The call to file is for compatibility with windows paths
+  file(TO_CMAKE_PATH "${PMI_ROOT}" PMI_ROOT)
+elseif(NOT "$ENV{PMI_ROOT}" STREQUAL "")
+  file(TO_CMAKE_PATH "$ENV{PMI_ROOT}" PMI_ROOT)
+else()
+  file(TO_CMAKE_PATH "${PMI_INCLUDE_DIR}" PMI_INCLUDE_DIR)
+  string(REPLACE "/include" "" PMI_ROOT "${PMI_INCLUDE_DIR}")
+endif()
+
+if(NOT PMI_LIBRARY OR NOT PMI_INCLUDE_DIR)
+  set(PMI_FOUND OFF)
+  return()
+endif()
+
+# Both the library and the header were located: let the standard helper set
+# PMI_FOUND and print the usual status line.
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(PMI DEFAULT_MSG PMI_LIBRARY PMI_INCLUDE_DIR)
+
+mark_as_advanced(PMI_ROOT PMI_LIBRARY PMI_INCLUDE_DIR)
diff --git a/ggml-oshmem.c b/ggml-oshmem.c
new file mode 100644
index 000000000..6acc3b5d4
--- /dev/null
+++ b/ggml-oshmem.c
@@ -0,0 +1,368 @@
+#include "ggml-oshmem.h"
+
+#include "ggml.h"
+
+#include <shmem.h>
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+#define UNUSED GGML_UNUSED
+
+// payload bytes carried by a single message of a tensor transfer
+#define OPENSHMEM_SYMMETRIC_BUFFER_SIZE 4096
+
+// bytes taken by the two int64_t header fields stored in front of the payload
+#define OPENSHMEM_SLOT_HEADER_SIZE ((int64_t)(2*sizeof(int64_t)))
+
+struct ggml_openshmem_context {
+    int pe;
+    int n_pes;
+    int64_t symmetric_buffer_size;
+    int64_t symmetric_comm_structure_size;
+    // symmetric, n_pes slots: slot i receives the messages put by PE i
+    uint8_t * symmetric_comm_structure;
+    // symmetric, 2*n_pes words: [0, n_pes) "data ready" set by PE i,
+    // [n_pes, 2*n_pes) "acknowledged" set by PE i
+    uint64_t * recv_signal;
+};
+
+void ggml_openshmem_backend_init(void) {
+    shmem_init();
+}
+
+void ggml_openshmem_backend_free(void) {
+    shmem_finalize();
+}
+
+// allocates the context and its symmetric buffers
+// collective: every PE has to call it, after ggml_openshmem_backend_init()
+struct ggml_openshmem_context * ggml_openshmem_init(void) {
+    struct ggml_openshmem_context * ctx = calloc(1, sizeof(struct ggml_openshmem_context));
+    GGML_ASSERT(ctx != NULL);
+
+    ctx->pe = shmem_my_pe();
+    ctx->n_pes = shmem_n_pes();
+
+    /*
+     * makes a symmetric heap allocation on all processing elements (processes running this SPMD program)
+     *
+     * the allocation holds one slot per PE, below is a struct representing the layout of a slot:
+     *
+     * {
+     *     int64_t offset_in_buffer, (number of chunks in the first message of a transfer)
+     *     int64_t length_in_buffer,
+     *     uint8_t buffer[OPENSHMEM_SYMMETRIC_BUFFER_SIZE]
+     * }
+     *
+     */
+    ctx->symmetric_buffer_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE;
+    ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + OPENSHMEM_SLOT_HEADER_SIZE;
+    ctx->symmetric_comm_structure = (uint8_t*)shmem_calloc(ctx->n_pes, ctx->symmetric_comm_structure_size);
+    GGML_ASSERT(ctx->symmetric_comm_structure != NULL);
+
+    /*
+     * uint64_t data_ready[shmem_n_pes()];
+     * uint64_t acknowledged[shmem_n_pes()];
+     */
+    ctx->recv_signal = (uint64_t*)shmem_calloc(2*(size_t)ctx->n_pes, sizeof(uint64_t));
+    GGML_ASSERT(ctx->recv_signal != NULL);
+
+    return ctx;
+}
+
+// releases the context and its symmetric buffers
+// collective (shmem_free synchronizes the PEs); has to run before ggml_openshmem_backend_free()
+void ggml_openshmem_free(struct ggml_openshmem_context * ctx) {
+    if (ctx == NULL) {
+        return;
+    }
+
+    shmem_free(ctx->recv_signal);
+    shmem_free(ctx->symmetric_comm_structure);
+    free(ctx);
+}
+
+int ggml_openshmem_pe(struct ggml_openshmem_context * ctx) {
+    return ctx->pe;
+}
+
+// copies n_tokens, n_past and n_threads from PE 0 to all other PEs
+void ggml_openshmem_eval_init(
+        struct ggml_openshmem_context * ctx_openshmem,
+        int * n_tokens,
+        int * n_past,
+        int * n_threads) {
+    UNUSED(ctx_openshmem);
+
+    // the operands of a broadcast have to be symmetric and the caller's ints are not:
+    // stage them in a symmetric array, [0, 3) is the source and [3, 6) the destination
+    int * staging = (int*)shmem_malloc(6*sizeof(int));
+    GGML_ASSERT(staging != NULL);
+
+    int * src = staging;
+    int * dst = staging + 3;
+
+    src[0] = *n_tokens;
+    src[1] = *n_past;
+    src[2] = *n_threads;
+
+    // synchronize the worker node parameters with the root node
+    shmem_barrier_all();
+
+    const int status = shmem_int_broadcast(SHMEM_TEAM_WORLD, dst, src, 3, 0);
+    GGML_ASSERT(status == 0);
+
+    if (shmem_my_pe() != 0) {
+        *n_tokens = dst[0];
+        *n_past = dst[1];
+        *n_threads = dst[2];
+    }
+
+    shmem_free(staging);
+}
+
+// index of the node named `name` in gf->nodes, -1 when there is none
+static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
+    struct ggml_tensor * t = ggml_graph_get_tensor(gf, name);
+    if (t == NULL) {
+        fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
+        return -1;
+    }
+
+    for (int i = 0; i < gf->n_nodes; i++) {
+        if (gf->nodes[i] == t) {
+            return i;
+        }
+    }
+
+    fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name);
+    return -1;
+}
+
+// number of bytes transferred for tensor t, only 32-bit element types are implemented
+static int64_t ggml_openshmem_tensor_size(const struct ggml_tensor * t) {
+    switch (t->type) {
+        case GGML_TYPE_I32: return ggml_nelements(t) * (int64_t)sizeof(int32_t);
+        case GGML_TYPE_F32: return ggml_nelements(t) * (int64_t)sizeof(float);
+        default: GGML_ASSERT(false && "not implemented");
+    }
+
+    return 0;
+}
+
+// blocks until the local signal word is set to 1 by a peer, then clears it for the next message
+static void ggml_openshmem_signal_wait(uint64_t * signal) {
+    shmem_uint64_wait_until(signal, SHMEM_CMP_EQ, 1);
+    (*signal) = 0;
+}
+
+// sends the data of tensor t to dst_pe, at most OPENSHMEM_SYMMETRIC_BUFFER_SIZE bytes per message
+//
+// every message is staged in this PE's own slot and put into the same slot on dst_pe together
+// with a "data ready" signal, the next message is sent once dst_pe has acknowledged the previous one
+static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int dst_pe) {
+    uint8_t * slot =
+        ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*ctx->pe);
+    uint8_t * slot_buffer =
+        slot+OPENSHMEM_SLOT_HEADER_SIZE;
+    // word on dst_pe announcing a message from this PE
+    uint64_t * dst_data_signal =
+        ctx->recv_signal+ctx->pe;
+    // local word set by dst_pe to acknowledge a message
+    uint64_t * my_ack_signal =
+        ctx->recv_signal+ctx->n_pes+dst_pe;
+
+    const int64_t xmt_size = ggml_openshmem_tensor_size(t);
+    const int64_t total_loop_count =
+        (xmt_size + OPENSHMEM_SYMMETRIC_BUFFER_SIZE - 1) / OPENSHMEM_SYMMETRIC_BUFFER_SIZE;
+
+    // first message: only a header, its first field announces the number of chunks
+    int64_t header[2] = { total_loop_count, 0 };
+    memcpy(slot, header, sizeof(header));
+
+    shmem_putmem_signal(slot, slot, sizeof(header), dst_data_signal, 1, SHMEM_SIGNAL_SET, dst_pe);
+    ggml_openshmem_signal_wait(my_ack_signal);
+
+    int64_t xmt_byte_offset = 0;
+
+    for (int64_t i = 0; i < total_loop_count; ++i) {
+        const int64_t xmt_byte_amount = MIN(OPENSHMEM_SYMMETRIC_BUFFER_SIZE, xmt_size - xmt_byte_offset);
+
+        header[0] = xmt_byte_offset;
+        header[1] = xmt_byte_amount;
+        memcpy(slot, header, sizeof(header));
+        memcpy(slot_buffer, ((uint8_t*)t->data)+xmt_byte_offset, xmt_byte_amount);
+
+        shmem_putmem_signal(slot, slot, sizeof(header)+xmt_byte_amount, dst_data_signal, 1, SHMEM_SIGNAL_SET, dst_pe);
+        ggml_openshmem_signal_wait(my_ack_signal);
+
+        xmt_byte_offset += xmt_byte_amount;
+    }
+}
+
+// receives the data of tensor t from src_pe, counterpart of ggml_openshmem_tensor_send
+static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int src_pe) {
+    // local slot that src_pe puts its messages into
+    uint8_t * slot =
+        ((uint8_t*)ctx->symmetric_comm_structure)+(ctx->symmetric_comm_structure_size*src_pe);
+    uint8_t * slot_buffer =
+        slot+OPENSHMEM_SLOT_HEADER_SIZE;
+    // local word set by src_pe when a message has arrived
+    uint64_t * my_data_signal =
+        ctx->recv_signal+src_pe;
+    // word on src_pe acknowledging a message on behalf of this PE
+    uint64_t * src_ack_signal =
+        ctx->recv_signal+ctx->n_pes+ctx->pe;
+
+    const int64_t rcv_size = ggml_openshmem_tensor_size(t);
+
+    int64_t header[2] = { 0, 0 };
+
+    ggml_openshmem_signal_wait(my_data_signal);
+    memcpy(header, slot, sizeof(header));
+    shmem_uint64_atomic_set(src_ack_signal, 1, src_pe);
+
+    const int64_t total_loop_count = header[0];
+
+    for (int64_t i = 0; i < total_loop_count; ++i) {
+        ggml_openshmem_signal_wait(my_data_signal);
+        memcpy(header, slot, sizeof(header));
+
+        const int64_t rcv_byte_offset = header[0];
+        const int64_t rcv_byte_amount = header[1];
+
+        GGML_ASSERT(rcv_byte_offset >= 0 && rcv_byte_amount >= 0);
+        GGML_ASSERT(rcv_byte_amount <= OPENSHMEM_SYMMETRIC_BUFFER_SIZE);
+        GGML_ASSERT(rcv_byte_offset + rcv_byte_amount <= rcv_size);
+
+        memcpy(((uint8_t*)t->data)+rcv_byte_offset, slot_buffer, rcv_byte_amount);
+
+        // the slot has been consumed, src_pe may overwrite it with the next message
+        shmem_uint64_atomic_set(src_ack_signal, 1, src_pe);
+    }
+}
+
+// TODO: there are many improvements that can be done to this implementation
+void ggml_openshmem_graph_compute_pre(
+        struct ggml_openshmem_context * ctx_openshmem,
+        struct ggml_cgraph * gf,
+        int n_layers) {
+    const int openshmem_rank = ctx_openshmem->pe;
+    const int openshmem_size = ctx_openshmem->n_pes;
+
+    struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens");
+    if (inp_tokens == NULL) {
+        fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
+        return;
+    }
+
+    struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0");
+    if (inp0 == NULL) {
+        fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
+        return;
+    }
+
+    GGML_ASSERT(inp0 == gf->nodes[0]);
+
+    // distribute the compute graph into slices across the PEs
+    //
+    // the main node (0) processes the last layers + the remainder of the compute graph
+    // and is responsible to pass the input tokens to the first node (1)
+    //
+    // node 1:   [(  0) * n_per_node, (  1) * n_per_node)
+    // node 2:   [(  1) * n_per_node, (  2) * n_per_node)
+    // ...
+    // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node)
+    // node 0:   [(n-1) * n_per_node,            n_nodes)
+    //
+    {
+        if (openshmem_rank > 0) {
+            if (openshmem_rank == 1) {
+                // the first node (1) receives the input tokens from the main node (0)
+                ggml_openshmem_tensor_recv(ctx_openshmem, inp_tokens, 0);
+            } else {
+                // every other node receives the output of the previous node into "inp0"
+                ggml_openshmem_tensor_recv(ctx_openshmem, inp0, openshmem_rank - 1);
+            }
+        }
+        else if (openshmem_size > 1) {
+            // node 0 sends the input tokens to node 1
+            ggml_openshmem_tensor_send(ctx_openshmem, inp_tokens, 1);
+
+            // recv the output data from the last node
+            ggml_openshmem_tensor_recv(ctx_openshmem, inp0, openshmem_size - 1);
+        }
+    }
+
+    {
+        const int n_per_node = (n_layers + (openshmem_size - 1)) / openshmem_size;
+
+        const int openshmem_idx = openshmem_rank > 0 ? openshmem_rank - 1 : openshmem_size - 1;
+
+        const int il0 = (openshmem_idx + 0) * n_per_node;
+        const int il1 = MIN(n_layers, (openshmem_idx + 1) * n_per_node);
+
+        char name_l0[GGML_MAX_NAME];
+        char name_l1[GGML_MAX_NAME];
+
+        snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0);
+        snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1);
+
+        const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0);
+        // only workers look up the end of their slice, node 0 runs to the end of the graph
+        const int idx_l1_inp = openshmem_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) : 0;
+
+        // checked before the "+ 1" below, which would turn a failed lookup (-1) into a valid looking 0
+        if (idx_l0 < 0 || idx_l1_inp < 0) {
+            fprintf(stderr, "%s: layer input nodes not found\n", __func__);
+            return;
+        }
+
+        const int idx_l1 = openshmem_rank > 0 ? idx_l1_inp + 1 : gf->n_nodes;
+
+        // attach the input data to all nodes that need it
+        // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below)
+        for (int i = idx_l0; i < idx_l1; i++) {
+            if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) {
+                gf->nodes[i]->src[0] = inp0;
+            }
+            if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) {
+                gf->nodes[i]->src[1] = inp0;
+            }
+        }
+
+        // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
+        for (int i = 1; i < idx_l1 - idx_l0; i++) {
+            gf->nodes[i] = gf->nodes[idx_l0 + i];
+            gf->grads[i] = gf->grads[idx_l0 + i];
+        }
+
+        // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node
+        if (openshmem_idx != 0) {
+            gf->nodes[0]->op = GGML_OP_NONE;
+        }
+
+        gf->n_nodes = idx_l1 - idx_l0;
+
+        //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, openshmem_rank, gf->n_nodes, il0, il1);
+    }
+}
+
+void ggml_openshmem_graph_compute_post(
+        struct ggml_openshmem_context * ctx_openshmem,
+        struct ggml_cgraph * gf,
+        int n_layers) {
+    UNUSED(n_layers);
+
+    const int openshmem_rank = ctx_openshmem->pe;
+    const int openshmem_size = ctx_openshmem->n_pes;
+
+    // send the output data to the next node
+    if (openshmem_rank > 0) {
+        ggml_openshmem_tensor_send(ctx_openshmem, gf->nodes[gf->n_nodes - 1], (openshmem_rank + 1) % openshmem_size);
+    }
+}
diff --git a/ggml-oshmem.h b/ggml-oshmem.h
new file mode 100644
index 000000000..fb953fb0f
--- /dev/null
+++ b/ggml-oshmem.h
@@ -0,0 +1,43 @@
+#pragma once
+#ifndef __LLAMA_CPP_GGML_OSHMEM_H__
+#define __LLAMA_CPP_GGML_OSHMEM_H__
+
+struct ggml_context;
+struct ggml_tensor;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_openshmem_context;
+
+void ggml_openshmem_backend_init(void);
+void ggml_openshmem_backend_free(void); + +struct ggml_openshmem_context * ggml_openshmem_init(void); +void ggml_openshmem_free(struct ggml_openshmem_context * ctx); + +int ggml_openshmem_rank(struct ggml_openshmem_context * ctx); + +void ggml_openshmem_eval_init( + struct ggml_openshmem_context * ctx_openshmem, + int * n_tokens, + int * n_past, + int * n_threads); + +void ggml_openshmem_graph_compute_pre( + struct ggml_openshmem_context * ctx_openshmem, + struct ggml_cgraph * gf, + int n_layers); + +void ggml_openshmem_graph_compute_post( + struct ggml_openshmem_context * ctx_openshmem, + struct ggml_cgraph * gf, + int n_layers); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/llama.cpp b/llama.cpp index edd2910b3..46318bed3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19,6 +19,9 @@ #ifdef GGML_USE_MPI # include "ggml-mpi.h" #endif +#ifdef GGML_USE_OPENSHMEM +# include "ggml-oshmem.h" +#endif #ifndef QK_K # ifdef GGML_QKK_64 # define QK_K 64