initial import

This commit is contained in:
ct-clmsn 2023-12-20 22:44:08 -05:00
parent 799fc22689
commit fcfe07f829
6 changed files with 1399 additions and 0 deletions

View file

@ -95,6 +95,7 @@ option(LLAMA_CLBLAST "llama: use CLBlast"
option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_MPI "llama: use MPI" OFF)
option(LLAMA_OPENSHMEM "llama: use OpenSHMEM" OFF)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@ -344,6 +345,29 @@ if (LLAMA_MPI)
endif() endif()
endif() endif()
if (LLAMA_OPENSHMEM)
    # Optional OpenSHMEM backend: locate the library through the
    # setup_openshmem() macro and, when found, add the ggml-oshmem sources,
    # the GGML_USE_OPENSHMEM definition and the link/include settings.
    #
    # cmake/FindOpenSHMEM.cmake uses find_library(... NO_CACHE), which exists
    # only since CMake 3.21. The version is checked explicitly instead of
    # calling cmake_minimum_required() here, because a second call in the
    # middle of the file resets the policy settings for everything after it.
    if (CMAKE_VERSION VERSION_LESS 3.21)
        message(FATAL_ERROR "LLAMA_OPENSHMEM requires CMake 3.21 or newer (running ${CMAKE_VERSION})")
    endif()

    include(cmake/FindOpenSHMEM.cmake)
    setup_openshmem()

    if (OPENSHMEM_FOUND)
        message(STATUS "OpenSHMEM found")

        set(GGML_HEADERS_OPENSHMEM ggml-oshmem.h)
        set(GGML_SOURCES_OPENSHMEM ggml-oshmem.c ggml-oshmem.h)

        add_compile_definitions(GGML_USE_OPENSHMEM)

        if (NOT MSVC)
            add_compile_options(-Wno-cast-qual)
        endif()

        # NOTE(review): OPENSHMEM_LDFLAGS may contain "SHELL:"-prefixed
        # entries produced by setup_openshmem(); confirm that the consumer of
        # LLAMA_EXTRA_LIBS accepts them.
        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${OPENSHMEM_LDFLAGS})

        # Only the -I<dir> entries are include directories. The flags may be
        # a CMake list or a single whitespace-separated string, so normalise
        # to a list first, then keep the directory part of each -I flag.
        string(REGEX REPLACE "[ \t\r\n]+" ";" OPENSHMEM_CFLAG_ITEMS "${OPENSHMEM_CFLAGS}")
        foreach(OPENSHMEM_CFLAG_ITEM IN LISTS OPENSHMEM_CFLAG_ITEMS)
            if (OPENSHMEM_CFLAG_ITEM MATCHES "^-I(.+)$")
                list(APPEND LLAMA_EXTRA_INCLUDES "${CMAKE_MATCH_1}")
            endif()
        endforeach()
    else()
        message(WARNING "OpenSHMEM not found")
    endif()
endif()
if (LLAMA_CLBLAST) if (LLAMA_CLBLAST)
find_package(CLBlast) find_package(CLBlast)
if (CLBlast_FOUND) if (CLBlast_FOUND)
@ -722,6 +746,7 @@ add_library(ggml OBJECT
${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
${GGML_SOURCES_OPENSHMEM} ${GGML_HEADERS_OPENSHMEM}
${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
) )

917
cmake/FindOpenSHMEM.cmake Normal file
View file

@ -0,0 +1,917 @@
# Copyright (c) 2019-2023 Ste||ar Group
#
# SPDX-License-Identifier: BSL-1.0
# Distributed under the Boost Software License, Version 1.0. (See accompanying
# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#
macro(setup_openshmem)
if(NOT TARGET PkgConfig::OPENSHMEM)
  # Discover an OpenSHMEM installation and expose it as the imported target
  # PkgConfig::OPENSHMEM.
  #
  # * LLAMA_MPI with a working MPI C toolchain: look for Open MPI's `oshmem`
  #   pkg-config module; if pkg-config cannot find it, ask `oshmem_info` /
  #   `ompi_info` for the lib and include directories and build the flags by
  #   hand.
  # * otherwise: look for the `osss-ucx` module, then `sandia-openshmem`.
  #
  # pkg_search_module() is provided by FindPkgConfig, so load it before use.
  find_package(PkgConfig QUIET)

  set(OPENSHMEM_PC "")
  find_package(MPI)

  if(LLAMA_MPI AND MPI_C_FOUND)
    # MPI_LIBDIR is only added to the search path when something defined it.
    if(MPI_LIBDIR)
      set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:${MPI_LIBDIR}/pkgconfig")
    endif()

    set(OPENSHMEM_PC "oshmem")
    pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})

    if(NOT OPENSHMEM_FOUND)
      find_program(OSHMEM_INFO NAMES oshmem_info ompi_info)
      if(NOT OSHMEM_INFO)
        message(
          FATAL_ERROR
            "oshmem_info and/or ompi_info not found! pkg-config cannot find OpenMPI's `${OPENSHMEM_PC}.pc`"
        )
      endif()

      # Logs are generated files, so they belong in the build tree.
      set(OSHMEM_INFO_OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/oshmem_info_stdout.log")
      set(OSHMEM_INFO_ERROR "${CMAKE_CURRENT_BINARY_DIR}/oshmem_info_error.log")

      # Run the tool directly; no shell is needed to pass two arguments.
      execute_process(
        COMMAND "${OSHMEM_INFO}" --path libdir
        WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
        RESULT_VARIABLE OSHMEM_INFO_STATUS
        OUTPUT_FILE "${OSHMEM_INFO_OUTPUT}"
        ERROR_FILE "${OSHMEM_INFO_ERROR}"
      )
      if(OSHMEM_INFO_STATUS)
        message(
          FATAL_ERROR
            "${OSHMEM_INFO} Failed! Program status code: ${OSHMEM_INFO_STATUS}"
        )
      endif()

      file(READ "${OSHMEM_INFO_OUTPUT}" OSHMEM_INFO_OUTPUT_CONTENT)
      # The directory is everything from the first '/' to the end of the
      # output, with surrounding whitespace removed.
      string(REGEX MATCH "(/.*)" OSHMEM_LIBDIR_PATH
                   "${OSHMEM_INFO_OUTPUT_CONTENT}"
      )
      string(STRIP "${OSHMEM_LIBDIR_PATH}" OSHMEM_LIBDIR_PATH)
      if("${OSHMEM_LIBDIR_PATH}" STREQUAL "")
        message(
          FATAL_ERROR
            "${OSHMEM_INFO} Failed! Check: ${OSHMEM_INFO_ERROR}\n${OSHMEM_INFO_OUTPUT_CONTENT}"
        )
      endif()

      # Retry pkg-config with the directory the tool reported.
      set(ENV{PKG_CONFIG_PATH}
          "$ENV{PKG_CONFIG_PATH}:${OSHMEM_LIBDIR_PATH}/pkgconfig"
      )
      pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})

      if(NOT OPENSHMEM_FOUND)
        set(OSHMEM_INFO_INCOUTPUT
            "${CMAKE_CURRENT_BINARY_DIR}/oshmem_info_stdout_inc.log"
        )
        set(OSHMEM_INFO_INCERROR
            "${CMAKE_CURRENT_BINARY_DIR}/oshmem_info_error_inc.log"
        )
        execute_process(
          COMMAND "${OSHMEM_INFO}" --path incdir
          WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
          RESULT_VARIABLE OSHMEM_INFO_INCSTATUS
          OUTPUT_FILE "${OSHMEM_INFO_INCOUTPUT}"
          ERROR_FILE "${OSHMEM_INFO_INCERROR}"
        )
        if(OSHMEM_INFO_INCSTATUS)
          message(
            FATAL_ERROR
              "${OSHMEM_INFO} Failed! Program status code: ${OSHMEM_INFO_INCSTATUS}"
          )
        endif()

        file(READ "${OSHMEM_INFO_INCOUTPUT}" OSHMEM_INFO_OUTPUT_INCCONTENT)
        string(REGEX MATCH "(/.*)" OSHMEM_INCDIR_PATH
                     "${OSHMEM_INFO_OUTPUT_INCCONTENT}"
        )
        string(STRIP "${OSHMEM_INCDIR_PATH}" OSHMEM_INCDIR_PATH)
        if("${OSHMEM_INCDIR_PATH}" STREQUAL "")
          message(
            FATAL_ERROR
              "${OSHMEM_INFO} Failed! Check: ${OSHMEM_INFO_INCERROR}\n${OSHMEM_INFO_OUTPUT_INCCONTENT}"
          )
        endif()

        # One list element per flag, so the flag-processing code that follows
        # in this macro sees individual options rather than one long string.
        set(OPENSHMEM_CFLAGS "-I${OSHMEM_INCDIR_PATH}" "-pthread"
                             "-I${OSHMEM_LIBDIR_PATH}"
        )
        set(OPENSHMEM_LDFLAGS "-loshmem")
        set(OPENSHMEM_LIBRARY_DIRS "${OSHMEM_LIBDIR_PATH}")
        add_library(PkgConfig::OPENSHMEM INTERFACE IMPORTED GLOBAL)
        set(OPENSHMEM_FOUND ON)
      endif()
    endif()
  else()
    include(cmake/FindOpenShmemPmi.cmake)

    # PMI_AUTOCONF_OPTS is computed here but not consumed in this macro.
    set(PMI_AUTOCONF_OPTS "")
    if(NOT PMI_LIBRARY OR NOT PMI_FOUND)
      set(PMI_AUTOCONF_OPTS "--enable-pmi-simple")
    else()
      # Keep only the first path component, e.g. /opt/pmi/include -> /opt.
      string(REGEX MATCH "(^/[^/]+)" PMI_INCLUDE_DIR_ROOT_PATH
                   "${PMI_INCLUDE_DIR}"
      )
      string(REGEX MATCH "(^/[^/]+)" PMI_LIBRARY_ROOT_PATH "${PMI_LIBRARY}")
      set(PMI_AUTOCONF_OPTS
          "--with-pmi=${PMI_INCLUDE_DIR_ROOT_PATH} --with-pmi-libdir=${PMI_LIBRARY_ROOT_PATH}"
      )
    endif()

    set(OPENSHMEM_PC "osss-ucx")
    pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})
    if(NOT OPENSHMEM_FOUND)
      set(OPENSHMEM_PC "sandia-openshmem")
      pkg_search_module(OPENSHMEM IMPORTED_TARGET GLOBAL ${OPENSHMEM_PC})
    endif()
  endif()
endif()
if(OPENSHMEM_CFLAGS)
  # Rebuild OPENSHMEM_CFLAGS so that a `--param` token and the flag that
  # follows it travel together as one "SHELL:" entry; every other flag is
  # copied through unchanged and in its original order.
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_CFLAGS})
    string(FIND "${X}" "--param" PARAM_FOUND)
    if(NOT "${PARAM_FOUND}" EQUAL "-1")
      # Hold the `--param` token back; it is emitted with the next flag.
      set(IS_PARAM "1")
      set(NEWPARAM "SHELL:${X}")
    elseif("${IS_PARAM}" EQUAL "1")
      # A single space separates the two halves of the pair.
      list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
      set(NEWPARAM "")
      set(IS_PARAM "0")
    else()
      list(APPEND FLAG_LIST "${X}")
    endif()
  endforeach()
  # Replace the old contents in one step instead of popping element by element.
  set(OPENSHMEM_CFLAGS "${FLAG_LIST}")
endif()
if(OPENSHMEM_CFLAGS_OTHER)
  # Rebuild OPENSHMEM_CFLAGS_OTHER so that a `--param` token and the flag
  # that follows it travel together as one "SHELL:" entry; every other flag
  # is copied through unchanged and in its original order.
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_CFLAGS_OTHER})
    string(FIND "${X}" "--param" PARAM_FOUND)
    if(NOT "${PARAM_FOUND}" EQUAL "-1")
      # Hold the `--param` token back; it is emitted with the next flag.
      set(IS_PARAM "1")
      set(NEWPARAM "SHELL:${X}")
    elseif("${IS_PARAM}" EQUAL "1")
      # A single space separates the two halves of the pair.
      list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
      set(NEWPARAM "")
      set(IS_PARAM "0")
    else()
      list(APPEND FLAG_LIST "${X}")
    endif()
  endforeach()
  # Replace the old contents in one step instead of popping element by element.
  set(OPENSHMEM_CFLAGS_OTHER "${FLAG_LIST}")
endif()
if(OPENSHMEM_LDFLAGS)
  # Scan OPENSHMEM_LDFLAGS for the `-lsma` library and the `-L` directories,
  # resolve the library to a file inside those directories, and append a
  # whole-archive style link entry for it. Flags containing `-Wl` are ignored
  # by the scan.
  #
  # NOTE(review): FLAG_LIST is built here but, unlike the *_OTHER / STATIC_*
  # sections of this macro, it is not written back to OPENSHMEM_LDFLAGS, so
  # the original `-lsma` entry stays in the list - confirm this is intended.
  unset(FOUND_LIB)
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  set(DIR_LIST "")
  set(LIB_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_LDFLAGS})
    string(FIND "${X}" "--param" PARAM_FOUND)
    string(FIND "${X}" "-lsma" IDX)
    string(FIND "${X}" "-l" LIDX)
    string(FIND "${X}" "-L" DIRIDX)
    string(FIND "${X}" "-Wl" SKIP)
    if("${SKIP}" EQUAL "-1")
      if(NOT "${PARAM_FOUND}" EQUAL "-1")
        # Hold the `--param` token back; it is emitted with the next flag.
        set(IS_PARAM "1")
        set(NEWPARAM "SHELL:${X}")
      endif()
      if("${PARAM_FOUND}" EQUAL "-1"
         AND "${IDX}" EQUAL "-1"
         AND "${IS_PARAM}" EQUAL "0"
      )
        list(APPEND FLAG_LIST "${X}")
      elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1")
        list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
        set(NEWPARAM "")
        set(IS_PARAM "0")
      elseif(NOT "${IDX}" EQUAL "-1")
        # `-lsma...` -> bare library name for find_library().
        string(REPLACE "-l" "" TMPSTR "${X}")
        list(APPEND LIB_LIST "${TMPSTR}")
      elseif(NOT "${LIDX}" EQUAL "-1")
        list(APPEND FLAG_LIST "${X}")
      endif()
      if(NOT "${DIRIDX}" EQUAL "-1")
        string(REPLACE "-L" "" TMPSTR "${X}")
        list(APPEND DIR_LIST "${TMPSTR}")
      endif()
    endif()
  endforeach()

  # Resolve every collected library inside the collected directories. The
  # result is tested with if(FOUND_LIB): a failed search leaves the value
  # "FOUND_LIB-NOTFOUND", which is a non-empty one-element list.
  set(WHOLE_LIBS "")
  foreach(X IN ITEMS ${LIB_LIST})
    foreach(Y IN ITEMS ${DIR_LIST})
      unset(FOUND_LIB)
      find_library(
        FOUND_LIB
        NAMES ${X} "lib${X}" "lib${X}.a"
        PATHS ${Y}
        HINTS ${Y} NO_CACHE
        NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH
      )
      if(FOUND_LIB)
        list(APPEND WHOLE_LIBS "${FOUND_LIB}")
        break() # first directory that has the library wins
      endif()
    endforeach()
  endforeach()
  set(FOUND_LIB "")

  # Nothing is appended when no library file was located.
  if(WHOLE_LIBS)
    string(REPLACE ";" " " NEWLINK "${WHOLE_LIBS}")
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
      list(APPEND OPENSHMEM_LDFLAGS
           "SHELL:-Wl,--whole-archive ${NEWLINK} -Wl,--no-whole-archive"
      )
    elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
      if(APPLE)
        # -force_load takes exactly one archive, so emit one entry per file.
        foreach(X IN ITEMS ${WHOLE_LIBS})
          list(APPEND OPENSHMEM_LDFLAGS "SHELL:-Wl,-force_load,${X}")
        endforeach()
      else()
        list(APPEND OPENSHMEM_LDFLAGS "SHELL:${NEWLINK}")
      endif()
    endif()
  endif()
endif()
if(OPENSHMEM_LDFLAGS_OTHER)
  # Rewrite OPENSHMEM_LDFLAGS_OTHER: `--param` pairs are fused into one
  # "SHELL:" entry, flags containing `-Wl` are dropped, and the `-lsma`
  # library is taken out of the list, resolved to a file inside the `-L`
  # directories and re-added as a whole-archive style link entry.
  unset(FOUND_LIB)
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  set(DIR_LIST "")
  set(LIB_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_LDFLAGS_OTHER})
    string(FIND "${X}" "--param" PARAM_FOUND)
    string(FIND "${X}" "-lsma" IDX)
    # LIDX must be computed for every flag of this list; it is read below.
    string(FIND "${X}" "-l" LIDX)
    string(FIND "${X}" "-L" DIRIDX)
    string(FIND "${X}" "-Wl" SKIP)
    if("${SKIP}" EQUAL "-1")
      if(NOT "${PARAM_FOUND}" EQUAL "-1")
        # Hold the `--param` token back; it is emitted with the next flag.
        set(IS_PARAM "1")
        set(NEWPARAM "SHELL:${X}")
      endif()
      if("${PARAM_FOUND}" EQUAL "-1"
         AND "${IDX}" EQUAL "-1"
         AND "${IS_PARAM}" EQUAL "0"
      )
        list(APPEND FLAG_LIST "${X}")
      elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1")
        list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
        set(NEWPARAM "")
        set(IS_PARAM "0")
      elseif(NOT "${IDX}" EQUAL "-1")
        # `-lsma...` -> bare library name for find_library().
        string(REPLACE "-l" "" TMPSTR "${X}")
        list(APPEND LIB_LIST "${TMPSTR}")
      elseif(NOT "${LIDX}" EQUAL "-1")
        list(APPEND FLAG_LIST "${X}")
      endif()
      if(NOT "${DIRIDX}" EQUAL "-1")
        string(REPLACE "-L" "" TMPSTR "${X}")
        list(APPEND DIR_LIST "${TMPSTR}")
      endif()
    endif()
  endforeach()
  # Replace the old contents in one step instead of popping element by element.
  set(OPENSHMEM_LDFLAGS_OTHER "${FLAG_LIST}")

  # Resolve every collected library inside the collected directories. The
  # result is tested with if(FOUND_LIB): a failed search leaves the value
  # "FOUND_LIB-NOTFOUND", which is a non-empty one-element list.
  set(WHOLE_LIBS "")
  foreach(X IN ITEMS ${LIB_LIST})
    foreach(Y IN ITEMS ${DIR_LIST})
      unset(FOUND_LIB)
      find_library(
        FOUND_LIB
        NAMES ${X} "lib${X}" "lib${X}.a"
        PATHS ${Y}
        HINTS ${Y} NO_CACHE
        NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH
      )
      if(FOUND_LIB)
        list(APPEND WHOLE_LIBS "${FOUND_LIB}")
        break() # first directory that has the library wins
      endif()
    endforeach()
  endforeach()
  set(FOUND_LIB "")

  # Nothing is appended when no library file was located.
  if(WHOLE_LIBS)
    string(REPLACE ";" " " NEWLINK "${WHOLE_LIBS}")
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
      list(APPEND OPENSHMEM_LDFLAGS_OTHER
           "SHELL:-Wl,--whole-archive ${NEWLINK} -Wl,--no-whole-archive"
      )
    elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
      # NOTE(review): this branch appends to OPENSHMEM_LDFLAGS while the GNU
      # branch appends to OPENSHMEM_LDFLAGS_OTHER - confirm which is intended.
      if(APPLE)
        # -force_load takes exactly one archive, so emit one entry per file.
        foreach(X IN ITEMS ${WHOLE_LIBS})
          list(APPEND OPENSHMEM_LDFLAGS "SHELL:-Wl,-force_load,${X}")
        endforeach()
      else()
        list(APPEND OPENSHMEM_LDFLAGS "SHELL:${NEWLINK}")
      endif()
    endif()
  endif()
endif()
if(OPENSHMEM_STATIC_CFLAGS)
  # Rebuild OPENSHMEM_STATIC_CFLAGS so that a `--param` token and the flag
  # that follows it travel together as one "SHELL:" entry; every other flag
  # is copied through unchanged and in its original order.
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_STATIC_CFLAGS})
    string(FIND "${X}" "--param" PARAM_FOUND)
    if(NOT "${PARAM_FOUND}" EQUAL "-1")
      # Hold the `--param` token back; it is emitted with the next flag.
      set(IS_PARAM "1")
      set(NEWPARAM "SHELL:${X}")
    elseif("${IS_PARAM}" EQUAL "1")
      # A single space separates the two halves of the pair.
      list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
      set(NEWPARAM "")
      set(IS_PARAM "0")
    else()
      list(APPEND FLAG_LIST "${X}")
    endif()
  endforeach()
  # Replace the old contents in one step instead of popping element by element.
  set(OPENSHMEM_STATIC_CFLAGS "${FLAG_LIST}")
endif()
if(OPENSHMEM_STATIC_CFLAGS_OTHER)
  # Rebuild OPENSHMEM_STATIC_CFLAGS_OTHER so that a `--param` token and the
  # flag that follows it travel together as one "SHELL:" entry; every other
  # flag is copied through unchanged and in its original order.
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_STATIC_CFLAGS_OTHER})
    string(FIND "${X}" "--param" PARAM_FOUND)
    if(NOT "${PARAM_FOUND}" EQUAL "-1")
      # Hold the `--param` token back; it is emitted with the next flag.
      set(IS_PARAM "1")
      set(NEWPARAM "SHELL:${X}")
    elseif("${IS_PARAM}" EQUAL "1")
      # A single space separates the two halves of the pair.
      list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
      set(NEWPARAM "")
      set(IS_PARAM "0")
    else()
      list(APPEND FLAG_LIST "${X}")
    endif()
  endforeach()
  # Replace the old contents in one step instead of popping element by element.
  set(OPENSHMEM_STATIC_CFLAGS_OTHER "${FLAG_LIST}")
endif()
if(OPENSHMEM_STATIC_LDFLAGS)
  # Rewrite OPENSHMEM_STATIC_LDFLAGS: `--param` pairs are fused into one
  # "SHELL:" entry, flags containing `-Wl` are dropped, and the OpenSHMEM
  # library is taken out of the list, resolved to a file inside the `-L`
  # directories and re-added as a whole-archive style link entry.
  unset(FOUND_LIB)
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  set(DIR_LIST "")
  set(LIB_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_STATIC_LDFLAGS})
    string(FIND "${X}" "--param" PARAM_FOUND)
    # NOTE(review): HPX_WITH_PARCELPORT_OPENSHMEM_CONDUIT is not set anywhere
    # in this file, so unless the user defines it the `-lsma` test is used -
    # confirm whether the `oshmem` case should key off OPENSHMEM_PC instead.
    if("${HPX_WITH_PARCELPORT_OPENSHMEM_CONDUIT}" STREQUAL "mpi")
      string(FIND "${X}" "-loshmem" IDX)
    else()
      string(FIND "${X}" "-lsma" IDX)
    endif()
    # LIDX must be computed for every flag of this list; it is read below.
    string(FIND "${X}" "-l" LIDX)
    string(FIND "${X}" "-L" DIRIDX)
    string(FIND "${X}" "-Wl" SKIP)
    if("${SKIP}" EQUAL "-1")
      if(NOT "${PARAM_FOUND}" EQUAL "-1")
        # Hold the `--param` token back; it is emitted with the next flag.
        set(IS_PARAM "1")
        set(NEWPARAM "SHELL:${X}")
      endif()
      if("${PARAM_FOUND}" EQUAL "-1"
         AND "${IDX}" EQUAL "-1"
         AND "${IS_PARAM}" EQUAL "0"
      )
        list(APPEND FLAG_LIST "${X}")
      elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1")
        list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
        set(NEWPARAM "")
        set(IS_PARAM "0")
      elseif(NOT "${IDX}" EQUAL "-1")
        # `-l<name>` -> bare library name for find_library().
        string(REPLACE "-l" "" TMPSTR "${X}")
        list(APPEND LIB_LIST "${TMPSTR}")
      elseif(NOT "${LIDX}" EQUAL "-1")
        list(APPEND FLAG_LIST "${X}")
      endif()
      if(NOT "${DIRIDX}" EQUAL "-1")
        string(REPLACE "-L" "" TMPSTR "${X}")
        list(APPEND DIR_LIST "${TMPSTR}")
      endif()
    endif()
  endforeach()
  # Replace the old contents in one step instead of popping element by element.
  set(OPENSHMEM_STATIC_LDFLAGS "${FLAG_LIST}")

  # Resolve every collected library inside the collected directories. The
  # result is tested with if(FOUND_LIB): a failed search leaves the value
  # "FOUND_LIB-NOTFOUND", which is a non-empty one-element list.
  set(WHOLE_LIBS "")
  foreach(X IN ITEMS ${LIB_LIST})
    foreach(Y IN ITEMS ${DIR_LIST})
      unset(FOUND_LIB)
      find_library(
        FOUND_LIB
        NAMES ${X} "lib${X}" "lib${X}.a"
        PATHS ${Y}
        HINTS ${Y} NO_CACHE
        NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH
      )
      if(FOUND_LIB)
        list(APPEND WHOLE_LIBS "${FOUND_LIB}")
        break() # first directory that has the library wins
      endif()
    endforeach()
  endforeach()
  set(FOUND_LIB "")

  # Nothing is appended when no library file was located.
  if(WHOLE_LIBS)
    string(REPLACE ";" " " NEWLINK "${WHOLE_LIBS}")
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
      list(APPEND OPENSHMEM_STATIC_LDFLAGS
           "SHELL:-Wl,--whole-archive ${NEWLINK} -Wl,--no-whole-archive"
      )
    elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
      # NOTE(review): this branch appends to OPENSHMEM_LDFLAGS while the GNU
      # branch appends to OPENSHMEM_STATIC_LDFLAGS - confirm which is intended.
      if(APPLE)
        # -force_load takes exactly one archive, so emit one entry per file.
        foreach(X IN ITEMS ${WHOLE_LIBS})
          list(APPEND OPENSHMEM_LDFLAGS "SHELL:-Wl,-force_load,${X}")
        endforeach()
      else()
        list(APPEND OPENSHMEM_LDFLAGS "SHELL:${NEWLINK}")
      endif()
    endif()
  endif()
endif()
if(OPENSHMEM_STATIC_LDFLAGS_OTHER)
  # Rewrite OPENSHMEM_STATIC_LDFLAGS_OTHER: `--param` pairs are fused into
  # one "SHELL:" entry, flags containing `-Wl` are dropped, and the OpenSHMEM
  # library is taken out of the list, resolved to a file inside the `-L`
  # directories and re-added as a whole-archive style link entry.
  unset(FOUND_LIB)
  set(IS_PARAM "0")
  set(NEWPARAM "")
  set(FLAG_LIST "")
  set(DIR_LIST "")
  set(LIB_LIST "")
  foreach(X IN ITEMS ${OPENSHMEM_STATIC_LDFLAGS_OTHER})
    string(FIND "${X}" "--param" PARAM_FOUND)
    # NOTE(review): HPX_WITH_PARCELPORT_OPENSHMEM_CONDUIT is not set anywhere
    # in this file, so unless the user defines it the `-lsma` test is used -
    # confirm whether the `oshmem` case should key off OPENSHMEM_PC instead.
    if("${HPX_WITH_PARCELPORT_OPENSHMEM_CONDUIT}" STREQUAL "mpi")
      string(FIND "${X}" "-loshmem" IDX)
    else()
      string(FIND "${X}" "-lsma" IDX)
    endif()
    # LIDX must be computed for every flag of this list; it is read below.
    string(FIND "${X}" "-l" LIDX)
    string(FIND "${X}" "-L" DIRIDX)
    string(FIND "${X}" "-Wl" SKIP)
    if("${SKIP}" EQUAL "-1")
      if(NOT "${PARAM_FOUND}" EQUAL "-1")
        # Hold the `--param` token back; it is emitted with the next flag.
        set(IS_PARAM "1")
        set(NEWPARAM "SHELL:${X}")
      endif()
      if("${PARAM_FOUND}" EQUAL "-1"
         AND "${IDX}" EQUAL "-1"
         AND "${IS_PARAM}" EQUAL "0"
      )
        list(APPEND FLAG_LIST "${X}")
      elseif("${PARAM_FOUND}" EQUAL "-1" AND "${IS_PARAM}" EQUAL "1")
        list(APPEND FLAG_LIST "${NEWPARAM} ${X}")
        set(NEWPARAM "")
        set(IS_PARAM "0")
      elseif(NOT "${IDX}" EQUAL "-1")
        # `-l<name>` -> bare library name for find_library().
        string(REPLACE "-l" "" TMPSTR "${X}")
        list(APPEND LIB_LIST "${TMPSTR}")
      elseif(NOT "${LIDX}" EQUAL "-1")
        list(APPEND FLAG_LIST "${X}")
      endif()
      if(NOT "${DIRIDX}" EQUAL "-1")
        string(REPLACE "-L" "" TMPSTR "${X}")
        list(APPEND DIR_LIST "${TMPSTR}")
      endif()
    endif()
  endforeach()
  # Replace the old contents in one step instead of popping element by element.
  set(OPENSHMEM_STATIC_LDFLAGS_OTHER "${FLAG_LIST}")

  # Resolve every collected library inside the collected directories. The
  # result is tested with if(FOUND_LIB): a failed search leaves the value
  # "FOUND_LIB-NOTFOUND", which is a non-empty one-element list.
  set(WHOLE_LIBS "")
  foreach(X IN ITEMS ${LIB_LIST})
    foreach(Y IN ITEMS ${DIR_LIST})
      unset(FOUND_LIB)
      find_library(
        FOUND_LIB
        NAMES ${X} "lib${X}" "lib${X}.a"
        PATHS ${Y}
        HINTS ${Y} NO_CACHE
        NO_CMAKE_FIND_ROOT_PATH NO_DEFAULT_PATH
      )
      if(FOUND_LIB)
        list(APPEND WHOLE_LIBS "${FOUND_LIB}")
        break() # first directory that has the library wins
      endif()
    endforeach()
  endforeach()
  set(FOUND_LIB "")

  # Nothing is appended when no library file was located.
  if(WHOLE_LIBS)
    string(REPLACE ";" " " NEWLINK "${WHOLE_LIBS}")
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
      list(APPEND OPENSHMEM_STATIC_LDFLAGS_OTHER
           "SHELL:-Wl,--whole-archive ${NEWLINK} -Wl,--no-whole-archive"
      )
    elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
      # NOTE(review): this branch appends to OPENSHMEM_LDFLAGS while the GNU
      # branch appends to OPENSHMEM_STATIC_LDFLAGS_OTHER - confirm which is
      # intended.
      if(APPLE)
        # -force_load takes exactly one archive, so emit one entry per file.
        foreach(X IN ITEMS ${WHOLE_LIBS})
          list(APPEND OPENSHMEM_LDFLAGS "SHELL:-Wl,-force_load,${X}")
        endforeach()
      else()
        list(APPEND OPENSHMEM_LDFLAGS "SHELL:${NEWLINK}")
      endif()
    endif()
  endif()
endif()
if(OPENSHMEM_DIR)
  # Wrap references to the ${OPENSHMEM_DIR}/install prefix in
  # $<BUILD_INTERFACE:...> inside the three flag lists.
  list(TRANSFORM OPENSHMEM_CFLAGS
       REPLACE "${OPENSHMEM_DIR}/install"
               "$<BUILD_INTERFACE:${OPENSHMEM_DIR}/install>"
  )
  list(TRANSFORM OPENSHMEM_LDFLAGS
       REPLACE "${OPENSHMEM_DIR}/install"
               "$<BUILD_INTERFACE:${OPENSHMEM_DIR}/install>"
  )
  list(TRANSFORM OPENSHMEM_LIBRARY_DIRS
       REPLACE "${OPENSHMEM_DIR}/install"
               "$<BUILD_INTERFACE:${OPENSHMEM_DIR}/install>"
  )
endif()

message(STATUS "OPENSHMEM_CFLAGS:\t${OPENSHMEM_CFLAGS}")
message(STATUS "OPENSHMEM_LDFLAGS:\t${OPENSHMEM_LDFLAGS}")
message(STATUS "OPENSHMEM_LIBRARY_DIRS:\t${OPENSHMEM_LIBRARY_DIRS}")

# Publish the processed flags on the imported target. The target exists only
# when discovery succeeded; without it OpenSHMEM is reported as not found so
# the caller can react, instead of set_target_properties() aborting the
# configure step on a missing target.
if(TARGET PkgConfig::OPENSHMEM)
  set_target_properties(
    PkgConfig::OPENSHMEM
    PROPERTIES INTERFACE_COMPILE_OPTIONS "${OPENSHMEM_CFLAGS}"
               INTERFACE_LINK_OPTIONS "${OPENSHMEM_LDFLAGS}"
               INTERFACE_LINK_DIRECTORIES "${OPENSHMEM_LIBRARY_DIRS}"
  )
  set(OPENSHMEM_FOUND ON)
else()
  set(OPENSHMEM_FOUND OFF)
endif()
endmacro()

View file

@ -0,0 +1,65 @@
# Copyright (c) 2023 Christopher Taylor
#
# SPDX-License-Identifier: BSL-1.0
# Distributed under the Boost Software License, Version 1.0. (See accompanying
# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
# Locate a PMI installation (header pmi2.h and library pmi).
#
# Inputs (all optional): PMI_ROOT / PMI_DIR as CMake or environment variables,
# plus whatever pkg-config reports for `cray-pmi` or `pmi`.
# Results: PMI_INCLUDE_DIR, PMI_LIBRARY, PMI_ROOT and PMI_FOUND.
find_package(PkgConfig QUIET)

# look for cray pmi...
pkg_check_modules(PC_PMI_CRAY QUIET cray-pmi)

# look for the rest if we couldn't find the cray package
if(NOT PC_PMI_CRAY_FOUND)
  pkg_check_modules(PC_PMI QUIET pmi)
endif()

find_path(
  PMI_INCLUDE_DIR pmi2.h
  HINTS ${PMI_ROOT}
        ENV
        PMI_ROOT
        ${PMI_DIR}
        ENV
        PMI_DIR
        ${PC_PMI_CRAY_INCLUDEDIR}
        ${PC_PMI_CRAY_INCLUDE_DIRS}
        ${PC_PMI_INCLUDEDIR}
        ${PC_PMI_INCLUDE_DIRS}
  PATH_SUFFIXES include
)

# Same hint set as the header search above, including PMI_DIR.
find_library(
  PMI_LIBRARY
  NAMES pmi
  HINTS ${PMI_ROOT}
        ENV
        PMI_ROOT
        ${PMI_DIR}
        ENV
        PMI_DIR
        ${PC_PMI_CRAY_LIBDIR}
        ${PC_PMI_CRAY_LIBRARY_DIRS}
        ${PC_PMI_LIBDIR}
        ${PC_PMI_LIBRARY_DIRS}
  PATH_SUFFIXES lib lib64
)

# Set PMI_ROOT in case the other hints are used
if(PMI_ROOT)
  # The call to file is for compatibility with windows paths
  file(TO_CMAKE_PATH "${PMI_ROOT}" PMI_ROOT)
elseif(NOT "$ENV{PMI_ROOT}" STREQUAL "")
  # Compare against the empty string: a path is not a boolean constant, so
  # testing the value directly would never take this branch.
  file(TO_CMAKE_PATH "$ENV{PMI_ROOT}" PMI_ROOT)
elseif(PMI_INCLUDE_DIR)
  # Derive the root from the header location by dropping a trailing
  # "/include" component only.
  file(TO_CMAKE_PATH "${PMI_INCLUDE_DIR}" PMI_INCLUDE_DIR)
  string(REGEX REPLACE "/include/?$" "" PMI_ROOT "${PMI_INCLUDE_DIR}")
endif()

if(NOT PMI_LIBRARY OR NOT PMI_INCLUDE_DIR)
  # PMI is optional for the caller: report "not found" and stop here.
  set(PMI_FOUND OFF)
  return()
endif()

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(PMI DEFAULT_MSG PMI_LIBRARY PMI_INCLUDE_DIR)
mark_as_advanced(PMI_ROOT PMI_LIBRARY PMI_INCLUDE_DIR)

346
ggml-oshmem.c Normal file
View file

@ -0,0 +1,346 @@
#include "ggml-oshmem.h"
#include "ggml.h"
#include <shmem.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define UNUSED GGML_UNUSED
#define OPENSHMEM_SYMMETRIC_BUFFER_SIZE 4096
// State kept by one process (PE) for the OpenSHMEM transport in this file.
// Filled in by ggml_openshmem_init().
struct ggml_openshmem_context {
// index of this PE, as returned by shmem_my_pe()
int pe;
// number of PEs, as returned by shmem_n_pes()
int n_pes;
// payload capacity of one communication slot, in bytes
// (set to OPENSHMEM_SYMMETRIC_BUFFER_SIZE)
int64_t symmetric_buffer_size;
// size of one communication slot in bytes: the payload plus two int64_t fields
int64_t symmetric_comm_structure_size;
// shmem_calloc'd region of n_pes * symmetric_comm_structure_size bytes
uint8_t * symmetric_comm_structure;
// shmem_calloc'd array of n_pes longs; the send/recv helpers wait on the
// entry at index `pe`
long * recv_signal;
};
// Starts the OpenSHMEM runtime for this process by calling shmem_init().
void ggml_openshmem_backend_init(void) {
shmem_init();
}
// Shuts the OpenSHMEM runtime down for this process by calling shmem_finalize().
void ggml_openshmem_backend_free(void) {
shmem_finalize();
}
// Allocates and fills the per-process OpenSHMEM context.
//
// shmem_calloc() is a collective call, so every PE has to call this function.
// An allocation failure aborts through GGML_ASSERT: returning an error on one
// PE only would leave the other PEs waiting inside a collective.
//
// Returns a heap-allocated context; release it with ggml_openshmem_free().
struct ggml_openshmem_context * ggml_openshmem_init(void) {
    struct ggml_openshmem_context * ctx = calloc(1, sizeof(struct ggml_openshmem_context));
    GGML_ASSERT(ctx != NULL);

    ctx->pe    = shmem_my_pe();
    ctx->n_pes = shmem_n_pes();

    /*
     * Symmetric heap allocation made on all processing elements (the processes
     * running this SPMD program).
     *
     * The allocation is an array with one slot per PE; each slot is laid out as:
     *
     * {
     *     int64_t offset_in_buffer,
     *     int64_t length_in_buffer,
     *     uint8_t buffer[OPENSHMEM_SYMMETRIC_BUFFER_SIZE]
     * }
     */
    ctx->symmetric_buffer_size         = OPENSHMEM_SYMMETRIC_BUFFER_SIZE;
    ctx->symmetric_comm_structure_size = OPENSHMEM_SYMMETRIC_BUFFER_SIZE + sizeof(int64_t) + sizeof(int64_t);
    ctx->symmetric_comm_structure      = (uint8_t *) shmem_calloc(ctx->n_pes, ctx->symmetric_comm_structure_size);
    GGML_ASSERT(ctx->symmetric_comm_structure != NULL);

    /*
     * long recv_signal[shmem_n_pes()]; one signal word per PE
     */
    ctx->recv_signal = (long *) shmem_calloc(ctx->n_pes, sizeof(long));
    GGML_ASSERT(ctx->recv_signal != NULL);

    return ctx;
}
void ggml_openshmem_free(struct ggml_openshmem_context * ctx) {
free(ctx);
}
// Returns the PE index stored in the context.
int ggml_openshmem_pe(struct ggml_openshmem_context * ctx) {
    return ctx->pe;
}

// Same value under the name that ggml-oshmem.h declares, so that callers of
// the public header link.
int ggml_openshmem_rank(struct ggml_openshmem_context * ctx) {
    return ggml_openshmem_pe(ctx);
}
void ggml_openshmem_eval_init(
struct ggml_openshmem_context * ctx_openshmem,
int * n_tokens,
int * n_past,
int * n_threads) {
UNUSED(ctx_openshmem);
// synchronize the worker node parameters with the root node
shmem_barrier_all();
shmem_broadcast(SHMEM_TEAM_WORLD, n_tokens, n_tokens, 1, 0);
shmem_broadcast(SHMEM_TEAM_WORLD, n_past, n_tokens, 1, 0);
shmem_broadcast(SHMEM_TEAM_WORLD, n_threads, n_tokens, 1, 0);
shmem_quiet();
}
// Looks up the tensor called `name` in `gf` and returns its position in
// gf->nodes, or -1 (after printing a message to stderr) when the tensor is
// unknown or is not one of the graph's nodes.
static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
    struct ggml_tensor * wanted = ggml_graph_get_tensor(gf, name);
    if (!wanted) {
        fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
        return -1;
    }

    // linear scan for the first node that is this tensor
    int pos = 0;
    while (pos < gf->n_nodes && gf->nodes[pos] != wanted) {
        pos++;
    }
    if (pos < gf->n_nodes) {
        return pos;
    }

    fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name);
    return -1;
}
// Sends the data of tensor `t` to PE `dst_pe` in chunks of at most
// ctx->symmetric_buffer_size bytes. Only I32 and F32 tensors are supported.
//
// Protocol (the counterpart is ggml_openshmem_tensor_recv on `dst_pe`):
//   1. put the number of chunks into the offset field of this PE's slot on
//      the destination and raise the destination's signal; wait for the ack
//   2. for every chunk: put {offset, length, payload} into the same slot,
//      raise the signal, wait for the ack
// The slot used is the one indexed by the sender's PE; the ack arrives on
// this PE's own entry of recv_signal.
static void ggml_openshmem_tensor_send(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int dst_pe) {
    // slot layout: [int64_t offset][int64_t length][payload bytes]
    uint8_t * slot        = ctx->symmetric_comm_structure + ctx->symmetric_comm_structure_size*ctx->pe;
    int64_t * slot_offset = (int64_t *) slot;
    int64_t * slot_length = slot_offset + 1;               // one int64_t further, not sizeof(int64_t) elements
    uint8_t * slot_buffer = (uint8_t *) (slot_length + 1);

    // the put-with-signal routines take the signal word as uint64_t
    GGML_ASSERT(sizeof(long) == sizeof(uint64_t));
    uint64_t * dst_recv_signal = (uint64_t *) (ctx->recv_signal + dst_pe);
    long     * my_recv_signal  = ctx->recv_signal + ctx->pe;

    size_t elem_size = 0;
    switch (t->type) {
        case GGML_TYPE_I32: elem_size = sizeof(int32_t); break;
        case GGML_TYPE_F32: elem_size = sizeof(float);   break;
        default: GGML_ASSERT(false && "not implemented");
    }

    const int64_t xmt_size  = ggml_nelements(t) * (int64_t) elem_size;
    const int64_t chunk_max = ctx->symmetric_buffer_size;
    // round up so that a trailing partial chunk is transmitted too
    const int64_t n_chunks  = (xmt_size + chunk_max - 1) / chunk_max;

    // step 1: announce how many chunks follow (byte-wise put: 8 bytes)
    *slot_offset = n_chunks;
    shmem_putmem_signal(slot_offset, slot_offset, sizeof(int64_t), dst_recv_signal, 1, SHMEM_SIGNAL_SET, dst_pe);
    shmem_wait_until(my_recv_signal, SHMEM_CMP_EQ, 1);
    (*my_recv_signal) = 0;

    // step 2: the payload, one acknowledged chunk at a time
    int64_t xmt_byte_offset = 0;
    for (int64_t i = 0; i < n_chunks; ++i) {
        const int64_t remaining       = xmt_size - xmt_byte_offset;
        const int64_t xmt_byte_amount = remaining < chunk_max ? remaining : chunk_max;

        *slot_offset = xmt_byte_offset;
        *slot_length = xmt_byte_amount;
        memcpy(slot_buffer, ((const uint8_t *) t->data) + xmt_byte_offset, (size_t) xmt_byte_amount);

        // header plus the bytes actually used; the signal is raised after the data
        shmem_putmem_signal(
            slot,
            slot,
            2*sizeof(int64_t) + (size_t) xmt_byte_amount,
            dst_recv_signal,
            1,
            SHMEM_SIGNAL_SET,
            dst_pe
        );

        // the slot may only be rewritten once the receiver has consumed it
        shmem_wait_until(my_recv_signal, SHMEM_CMP_EQ, 1);
        (*my_recv_signal) = 0;

        xmt_byte_offset += xmt_byte_amount;
    }

    shmem_fence();
}
// Receives tensor data sent by ggml_openshmem_tensor_send on PE `src_pe` and
// writes it into t->data.
//
// Protocol: wait for the chunk count in the offset field of the sender's
// slot, acknowledge it, then for every chunk wait for the signal, copy
// `length` payload bytes to t->data + `offset`, and acknowledge. Each slot is
// read before its acknowledgement is sent, because the sender reuses the slot
// as soon as it sees the ack.
static void ggml_openshmem_tensor_recv(struct ggml_openshmem_context * ctx, struct ggml_tensor * t, int src_pe) {
    // slot layout: [int64_t offset][int64_t length][payload bytes]
    uint8_t * slot        = ctx->symmetric_comm_structure + ctx->symmetric_comm_structure_size*src_pe;
    int64_t * slot_offset = (int64_t *) slot;
    int64_t * slot_length = slot_offset + 1;               // one int64_t further, not sizeof(int64_t) elements
    uint8_t * slot_buffer = (uint8_t *) (slot_length + 1);

    // the put-with-signal routines take the signal word as uint64_t
    GGML_ASSERT(sizeof(long) == sizeof(uint64_t));
    uint64_t * src_recv_signal = (uint64_t *) (ctx->recv_signal + src_pe);
    long     * my_recv_signal  = ctx->recv_signal + ctx->pe;

    const int64_t dst_capacity = (int64_t) ggml_nbytes(t);

    // the first message carries the number of chunks in the offset field
    shmem_wait_until(my_recv_signal, SHMEM_CMP_EQ, 1);
    (*my_recv_signal) = 0;
    const int64_t total_loop_count = *slot_offset;

    // zero-length put: only the sender's signal is raised (acknowledgement)
    shmem_putmem_signal(slot, slot, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe);

    for (int64_t i = 0; i < total_loop_count; ++i) {
        shmem_wait_until(my_recv_signal, SHMEM_CMP_EQ, 1);
        (*my_recv_signal) = 0;

        const int64_t offset = *slot_offset;
        const int64_t length = *slot_length;

        // never write outside the slot payload or outside the tensor
        GGML_ASSERT(offset >= 0 && length >= 0);
        GGML_ASSERT(length <= ctx->symmetric_buffer_size);
        GGML_ASSERT(offset + length <= dst_capacity);

        // the payload always starts at the beginning of the slot buffer
        memcpy(((uint8_t *) t->data) + offset, slot_buffer, (size_t) length);

        shmem_putmem_signal(slot, slot, 0, src_recv_signal, 1, SHMEM_SIGNAL_SET, src_pe);
    }

    shmem_fence();
}
// TODO: there are many improvements that can be done to this implementation
// Prepares the compute graph of this PE for pipelined evaluation.
//
// ctx_openshmem  context from ggml_openshmem_init()
// gf             compute graph; it is modified in place so that it contains
//                only the nodes of the layer range assigned to this PE
// n_layers       number of layers of the model
//
// The function first exchanges the input data (PE 0 sends the tokens to PE 1
// and then waits for the result of the last PE; every other PE waits for its
// predecessor), then trims the graph. On a lookup failure it prints a message
// to stderr and returns without trimming the graph.
void ggml_openshmem_graph_compute_pre(
        struct ggml_openshmem_context * ctx_openshmem,
        struct ggml_cgraph * gf,
        int n_layers) {
    const int openshmem_rank = ctx_openshmem->pe;
    const int openshmem_size = ctx_openshmem->n_pes;

    struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens");
    if (inp_tokens == NULL) {
        fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
        return;
    }

    struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0");
    if (inp0 == NULL) {
        fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
        return;
    }

    GGML_ASSERT(inp0 == gf->nodes[0]);

    // distribute the compute graph into slices across the OpenSHMEM PEs
    //
    // the main PE (0) processes the last layers + the remainder of the compute graph
    // and is responsible to pass the input tokens to the first PE (1)
    //
    // PE 1:   [( 0) * n_per_node, ( 1) * n_per_node)
    // PE 2:   [( 1) * n_per_node, ( 2) * n_per_node)
    // ...
    // PE n-1: [(n-2) * n_per_node, (n-1) * n_per_node)
    // PE 0:   [(n-1) * n_per_node, n_nodes)
    //
    if (openshmem_rank > 0) {
        if (openshmem_rank == 1) {
            // the first PE (1) receives the input tokens from the main PE (0)
            ggml_openshmem_tensor_recv(ctx_openshmem, inp_tokens, 0);
        } else {
            // every later PE receives the output of its predecessor into "inp0",
            // i.e. the first node of the compute graph
            ggml_openshmem_tensor_recv(ctx_openshmem, inp0, openshmem_rank - 1);
        }
    } else if (openshmem_size > 1) {
        // PE 0 sends the input tokens to PE 1
        ggml_openshmem_tensor_send(ctx_openshmem, inp_tokens, 1);

        // recv the output data from the last PE
        ggml_openshmem_tensor_recv(ctx_openshmem, inp0, openshmem_size - 1);
    }

    {
        const int n_per_node = (n_layers + (openshmem_size - 1)) / openshmem_size;

        const int openshmem_idx = openshmem_rank > 0 ? openshmem_rank - 1 : openshmem_size - 1;

        const int il0 =               (openshmem_idx + 0) * n_per_node;
        const int il1 = MIN(n_layers, (openshmem_idx + 1) * n_per_node);

        char name_l0[GGML_MAX_NAME];
        char name_l1[GGML_MAX_NAME];

        snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0);
        snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1);

        const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0);

        // one past the last node of this slice; a failed lookup has to stay
        // negative, so the "+ 1" is applied only to a valid index
        int idx_l1 = gf->n_nodes;
        if (openshmem_rank > 0) {
            const int idx_end = ggml_graph_get_node_idx(gf, name_l1);
            idx_l1 = idx_end < 0 ? -1 : idx_end + 1;
        }

        if (idx_l0 < 0 || idx_l1 < 0) {
            fprintf(stderr, "%s: layer input nodes not found\n", __func__);
            return;
        }

        // attach the input data to all nodes that need it
        // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below)
        for (int i = idx_l0; i < idx_l1; i++) {
            if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) {
                gf->nodes[i]->src[0] = inp0;
            }
            if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) {
                gf->nodes[i]->src[1] = inp0;
            }
        }

        // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
        for (int i = 1; i < idx_l1 - idx_l0; i++) {
            gf->nodes[i] = gf->nodes[idx_l0 + i];
            gf->grads[i] = gf->grads[idx_l0 + i];
        }

        // the first PE performs the "get_rows" operation, the rest of the PEs get the data from the previous PE
        if (openshmem_idx != 0) {
            gf->nodes[0]->op = GGML_OP_NONE;
        }

        gf->n_nodes = idx_l1 - idx_l0;

        //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, openshmem_rank, gf->n_nodes, il0, il1);
    }
}
// Forwards the result of this PE's graph slice to the next PE.
//
// The tensor sent is the last node of `gf`; the target is (pe + 1) modulo the
// number of PEs. PE 0 sends nothing. `n_layers` is accepted but not used.
void ggml_openshmem_graph_compute_post(
        struct ggml_openshmem_context * ctx_openshmem,
        struct ggml_cgraph * gf,
        int n_layers) {
    UNUSED(n_layers);

    const int my_pe   = ctx_openshmem->pe;
    const int pe_count = ctx_openshmem->n_pes;

    // only PEs above 0 pass their output on
    if (my_pe <= 0) {
        return;
    }

    struct ggml_tensor * output  = gf->nodes[gf->n_nodes - 1];
    const int            next_pe = (my_pe + 1) % pe_count;

    ggml_openshmem_tensor_send(ctx_openshmem, output, next_pe);
}

43
ggml-oshmem.h Normal file
View file

@ -0,0 +1,43 @@
#pragma once
#ifndef __LLAMA_CPP_GGML_OSHMEM_H__
#define __LLAMA_CPP_GGML_OSHMEM_H__

// Public interface of the OpenSHMEM transport (ggml-oshmem.c).

struct ggml_context;
struct ggml_tensor;
struct ggml_cgraph;

#ifdef __cplusplus
extern "C" {
#endif

// Opaque per-process state; defined in ggml-oshmem.c.
struct ggml_openshmem_context;

// Start / shut down the OpenSHMEM runtime of this process.
void ggml_openshmem_backend_init(void);
void ggml_openshmem_backend_free(void);

// Create / release the per-process context.
struct ggml_openshmem_context * ggml_openshmem_init(void);
void ggml_openshmem_free(struct ggml_openshmem_context * ctx);

// Index of the calling PE as stored in the context. Both names are declared:
// ggml_openshmem_pe is the function ggml-oshmem.c defines,
// ggml_openshmem_rank is the name this header has always exposed.
int ggml_openshmem_pe(struct ggml_openshmem_context * ctx);
int ggml_openshmem_rank(struct ggml_openshmem_context * ctx);

// Synchronize the evaluation parameters across PEs before an evaluation.
void ggml_openshmem_eval_init(
        struct ggml_openshmem_context * ctx_openshmem,
        int * n_tokens,
        int * n_past,
        int * n_threads);

// Called before the graph is computed: exchanges the input data between PEs
// and restricts `gf` to the layer range of the calling PE.
void ggml_openshmem_graph_compute_pre(
        struct ggml_openshmem_context * ctx_openshmem,
        struct ggml_cgraph * gf,
        int n_layers);

// Called after the graph is computed: PEs other than 0 send their last graph
// node to the next PE.
void ggml_openshmem_graph_compute_post(
        struct ggml_openshmem_context * ctx_openshmem,
        struct ggml_cgraph * gf,
        int n_layers);

#ifdef __cplusplus
}
#endif

#endif

View file

@ -19,6 +19,9 @@
#ifdef GGML_USE_MPI #ifdef GGML_USE_MPI
# include "ggml-mpi.h" # include "ggml-mpi.h"
#endif #endif
#ifdef GGML_USE_OPENSHMEM
# include "ggml-oshmem.h"
#endif
#ifndef QK_K #ifndef QK_K
# ifdef GGML_QKK_64 # ifdef GGML_QKK_64
# define QK_K 64 # define QK_K 64