nix: add cuda, use a symlinked toolkit instead for cmake

This commit is contained in:
Green Sky 2023-09-18 15:48:24 +02:00
parent 51a7cf5c6e
commit 9d4365df40
No known key found for this signature in database

View file

@ -35,6 +35,20 @@
); );
pkgs = import nixpkgs { inherit system; }; pkgs = import nixpkgs { inherit system; };
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ]; nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
cudatoolkit_joined = with pkgs; symlinkJoin {
# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
# see https://github.com/NixOS/nixpkgs/issues/224291
# copied from jaxlib
name = "${cudaPackages.cudatoolkit.name}-merged";
paths = [
cudaPackages.cudatoolkit.lib
cudaPackages.cudatoolkit.out
] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
# for some reason some of the required libs are in the targets/x86_64-linux
# directory; not sure why but this works around it
"${cudaPackages.cudatoolkit}/targets/${system}"
];
};
llama-python = llama-python =
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]); pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
postPatch = '' postPatch = ''
@ -70,6 +84,13 @@
"-DLLAMA_CLBLAST=ON" "-DLLAMA_CLBLAST=ON"
]; ];
}; };
packages.cuda = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
cmakeFlags = cmakeFlags ++ [
"-DLLAMA_CUBLAS=ON"
];
};
packages.rocm = pkgs.stdenv.mkDerivation { packages.rocm = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall; inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ]; buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];