Add cmake for MUSA support

This commit is contained in:
dixyes 2024-04-16 14:18:38 +08:00
parent b7499e0460
commit dfb6a0139c
7 changed files with 191 additions and 1 deletions

View file

@ -102,6 +102,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_MUSA "llama: use MUSA" OFF)
option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_VULKAN "llama: use Vulkan" OFF)
@ -574,6 +575,49 @@ if (LLAMA_HIPBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
endif()
if (LLAMA_MUSA)
option(MUSA_ARCH "MUSA architecture" "21")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
find_package(MUSA REQUIRED)
message(STATUS "MUSA found")
enable_language(MUSA)
set(GGML_HEADERS_MUSA ggml-cuda.h)
file(GLOB GGML_SOURCES_MUSA "ggml-cuda/*.cu")
list(APPEND GGML_SOURCES_MUSA "ggml-cuda.cu")
add_compile_definitions(GGML_USE_MUSA GGML_USE_CUDA)
if (LLAMA_CUDA_FORCE_DMMV)
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
endif()
if (LLAMA_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
if (LLAMA_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE MUSA)
if (LLAMA_STATIC)
message(FATAL_ERROR "Static linking not supported for MUSA")
endif()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} PUBLIC MUSA::musa MUSA::mublas MUSA::musart)
endif()
if (LLAMA_SYCL)
if (NOT LLAMA_SYCL_TARGET MATCHES "^(INTEL|NVIDIA)$")
message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL or NVIDIA")
@ -1160,6 +1204,7 @@ add_library(ggml OBJECT
${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
${GGML_SOURCES_MUSA} ${GGML_HEADERS_MUSA}
)
target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})

View file

@ -567,7 +567,7 @@ endif # LLAMA_HIPBLAS
ifdef LLAMA_MUSA
MUSA_PATH ?= /usr/local/musa
MUSA_ARCH ?= 10
MUSA_ARCH ?= 21
MCC ?= $(CCACHE) $(MUSA_PATH)/bin/mcc
LLAMA_CUDA_DMMV_X ?= 32
LLAMA_CUDA_MMV_Y ?= 1

View file

@ -0,0 +1,11 @@
set(CMAKE_MUSA_ARCHITECTURES "mp_${MUSA_ARCH}")
set(CMAKE_MUSA_COMPILER "${MUSA_MCC}")
set(CMAKE_MUSA_COMPILER_ID "Clang")
set(CMAKE_MUSA_COMPILER_ARG1 "")
set(CMAKE_MUSA_COMPILER_ENV_VAR "MCC")
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeMUSACompiler.cmake.in
${CMAKE_PLATFORM_INFO_DIR}/CMakeMUSACompiler.cmake
)

View file

@ -0,0 +1,6 @@
set(CMAKE_MUSA_COMPILER "@CMAKE_MUSA_COMPILER@")
set(CMAKE_MUSA_COMPILER_ARG1 "@CMAKE_MUSA_COMPILER_ARG1@")
set(CMAKE_MUSA_COMPILER_LOADED 1)
set(CMAKE_MUSA_SOURCE_FILE_EXTENSIONS mu;cu)
set(CMAKE_MUSA_OUTPUT_EXTENSION .o)
set(CMAKE_MUSA_COMPILER_ENV_VAR "MUSA")

View file

@ -0,0 +1,26 @@
# reuse cxx things
include(CMakeLanguageInformation)
include(CMakeCommonLanguageInclude)
include(Compiler/Clang)
__compiler_clang(MUSA)
__compiler_clang_cxx_standards(MUSA)
set(CMAKE_INCLUDE_FLAG_MUSA "-I")
set(CMAKE_MUSA_RUNTIME_LIBRARY_DEFAULT "SHARED")
set(CMAKE_MUSA_RUNTIME_LIBRARY_LINK_OPTIONS_STATIC "")
set(CMAKE_MUSA_RUNTIME_LIBRARY_LINK_OPTIONS_SHARED "")
# Populated by CMakeHIPInformation.cmake
set(CMAKE_MUSA_RUNTIME_LIBRARIES_STATIC "")
set(CMAKE_MUSA_RUNTIME_LIBRARIES_SHARED "")
# compile a C++ file into an object file
if(NOT CMAKE_MUSA_COMPILE_OBJECT)
set(CMAKE_MUSA_COMPILE_OBJECT
"<CMAKE_MUSA_COMPILER> -x musa --cuda-gpu-arch=${CMAKE_MUSA_ARCHITECTURES} -fPIC <DEFINES> <INCLUDES> <FLAGS> -o <OBJECT> -c <SOURCE>")
endif()

View file

@ -0,0 +1 @@
# do nothing, make cmake happy

101
cmake/FindMUSA.cmake Normal file
View file

@ -0,0 +1,101 @@
# find MUSA things
include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
include(${CMAKE_ROOT}/Modules/SelectLibraryConfigurations.cmake)
include(${CMAKE_ROOT}/Modules/CMakeFindDependencyMacro.cmake)
if(DEFINED ENV{MUSA_HOME})
set(MUSA_HOME $ENV{MUSA_HOME})
else()
set(MUSA_HOME /usr/local/musa)
endif()
set(MUSA_MCC ${MUSA_HOME}/bin/mcc)
if (DEFINED ENV{MUSA_ARCH})
set(MUSA_ARCH $ENV{MUSA_ARCH})
elseif(NOT MUSA_ARCH)
set(MUSA_ARCH "21")
endif()
if(NOT MUSA_INCLUDE_DIR)
set(MUSA_INCLUDE_DIR ${MUSA_HOME}/include)
endif()
if(NOT MUSA_LIBRARY_DIR)
set(MUSA_LIBRARY_DIR ${MUSA_HOME}/lib)
endif()
if(NOT MUSA_LIBRARIES)
find_library(
MUSA_MUSA_LIBRARY
NAMES libmusa.so
PATHS ${MUSA_LIBRARY_DIR}
)
find_library(
MUSA_MUBLAS_LIBRARY
NAMES libmublas.so
PATHS ${MUSA_LIBRARY_DIR}
)
find_library(
MUSA_MUSART_LIBRARY
NAMES libmusart.so
PATHS ${MUSA_LIBRARY_DIR}
)
if(MUSA_MUSA_LIBRARY AND MUSA_MUBLAS_LIBRARY AND MUSA_MUSART_LIBRARY)
set(MUSA_LIBRARIES "${MUSA_MUSA_LIBRARY};${MUSA_MUBLAS_LIBRARY};${MUSA_MUSART_LIBRARY}")
set(MUSA_MUSA_LIBRARY CACHE STRING "${MUSA_MUSA_LIBRARY}")
set(MUSA_MUBLAS_LIBRARY CACHE STRING "${MUSA_MUBLAS_LIBRARY}")
set(MUSA_MUSART_LIBRARY CACHE STRING "${MUSA_MUSART_LIBRARY}")
endif()
endif()
if(MUSA_LIBRARIES)
if(NOT TARGET MUSA::musa)
add_library(MUSA::musa SHARED IMPORTED)
set_target_properties(MUSA::musa PROPERTIES
IMPORTED_LOCATION ${MUSA_MUSA_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${MUSA_INCLUDE_DIR}
)
endif()
if(NOT TARGET MUSA::mublas)
add_library(MUSA::mublas SHARED IMPORTED)
set_target_properties(MUSA::mublas PROPERTIES
IMPORTED_LOCATION ${MUSA_MUBLAS_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${MUSA_INCLUDE_DIR}
)
endif()
if(NOT TARGET MUSA::musart)
add_library(MUSA::musart SHARED IMPORTED)
set_target_properties(MUSA::musart PROPERTIES
IMPORTED_LOCATION ${MUSA_MUSART_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${MUSA_INCLUDE_DIR}
)
endif()
set(MUSA_INCLUDE_DIR ${MUSA_INCLUDE_DIR} CACHE STRING "")
set(MUSA_LIBRARY_DIR ${MUSA_LIBRARY_DIR} CACHE STRING "")
set(MUSA_LIBRARIES ${MUSA_LIBRARIES} CACHE STRING "")
endif()
find_package_handle_standard_args(
MUSA
REQUIRED_VARS
MUSA_ARCH
MUSA_MCC
MUSA_INCLUDE_DIR
MUSA_LIBRARIES
MUSA_LIBRARY_DIR
)
mark_as_advanced(
MUSA_INCLUDE_DIR
MUSA_LIBRARIES
MUSA_LIBRARY_DIR
CMAKE_MUSA_ARCHITECTURES
CMAKE_MUSA_COMPILER
)