Fix multi-GPU on multiple AMD architectures with rocblas_initialize() (#5)

* Initialize rocBLAS
This commit is contained in:
YellowRoseCx 2023-07-24 03:52:01 -05:00 committed by GitHub
parent 3db70b5f0a
commit 1f6294dc44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -10,6 +10,7 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hipblas/hipblas.h> #include <hipblas/hipblas.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include "rocblas/rocblas.h"
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
@@ -2531,6 +2532,10 @@ void ggml_init_cublas() {
static bool initialized = false; static bool initialized = false;
if (!initialized) { if (!initialized) {
#ifdef GGML_USE_HIPBLAS
rocblas_initialize();
hipDeviceSynchronize();
#endif
CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0; int64_t total_vram = 0;