From 7bd8d8c6d7de67975fbb9681990d4d4af5b6bbab Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Tue, 26 Dec 2023 22:23:30 +0000 Subject: [PATCH] nix: explicit mpi support --- .devops/nix/package.nix | 18 +++++++++++------- flake.nix | 3 +++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index 2d0099457..5f2a7c9f4 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -22,6 +22,7 @@ ], useCuda ? config.cudaSupport, useMetalKit ? stdenv.isAarch64 && stdenv.isDarwin && !useOpenCL, + useMpi ? false, # Increases the runtime closure size by ~700M useOpenCL ? false, useRocm ? config.rocmSupport, llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake @@ -42,11 +43,12 @@ let effectiveStdenv = if useCuda then cudaPackages.backendStdenv else inputs.stdenv; suffices = - lib.optionals useOpenCL [ "OpenCL" ] + lib.optionals useBlas [ "BLAS" ] ++ lib.optionals useCuda [ "CUDA" ] - ++ lib.optionals useRocm [ "ROCm" ] ++ lib.optionals useMetalKit [ "MetalKit" ] - ++ lib.optionals useBlas [ "BLAS" ]; + ++ lib.optionals useMpi [ "MPI" ] + ++ lib.optionals useOpenCL [ "OpenCL" ] + ++ lib.optionals useRocm [ "ROCm" ]; pnameSuffix = strings.optionalString (suffices != [ ]) @@ -149,11 +151,11 @@ effectiveStdenv.mkDerivation ( ]; buildInputs = - [ mpi ] - ++ optionals useOpenCL [ clblast ] + optionals effectiveStdenv.isDarwin darwinBuildInputs ++ optionals useCuda cudaBuildInputs - ++ optionals useRocm rocmBuildInputs - ++ optionals effectiveStdenv.isDarwin darwinBuildInputs; + ++ optionals useMpi [ mpi ] + ++ optionals useOpenCL [ clblast ] + ++ optionals useRocm rocmBuildInputs; cmakeFlags = [ @@ -166,6 +168,7 @@ effectiveStdenv.mkDerivation ( (cmakeBool "LLAMA_CUBLAS" useCuda) (cmakeBool "LLAMA_HIPBLAS" useRocm) (cmakeBool "LLAMA_METAL" useMetalKit) + (cmakeBool "LLAMA_MPI" useMpi) ] ++ optionals useCuda [ ( @@ -203,6 +206,7 @@ effectiveStdenv.mkDerivation ( useBlas useCuda useMetalKit + useMpi useOpenCL useRocm ; diff --git a/flake.nix b/flake.nix index 3575cbf12..d240ececa 100644 --- a/flake.nix +++ b/flake.nix @@ -85,6 +85,9 @@ opencl = config.packages.default.override { useOpenCL = true; }; cuda = (pkgsCuda.callPackage .devops/nix/scope.nix { inherit llamaVersion; }).llama-cpp; rocm = (pkgsRocm.callPackage .devops/nix/scope.nix { inherit llamaVersion; }).llama-cpp; + + mpi-cpu = config.packages.default.override { useMpi = true; }; + mpi-cuda = config.packages.default.override { useMpi = true; }; }; }; };