CUDA: mmq CLI option, fixed mmq build issues (#2453)

This commit is contained in:
Johannes Gäßler 2023-07-31 15:44:35 +02:00 committed by GitHub
parent 1215ed7d5c
commit 0728c5a8b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 67 additions and 27 deletions

View file

@ -3898,10 +3898,9 @@ static size_t g_scratch_offset = 0;
static int g_device_count = -1;
static int g_main_device = 0;
#ifndef GGML_CUDA_FORCE_DMMV
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
#endif
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static bool g_mul_mat_q = false;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@ -3923,9 +3922,7 @@ void ggml_init_cublas() {
g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
#ifndef GGML_CUDA_FORCE_DMMV
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
#endif
}
for (int id = 0; id < g_device_count; ++id) {
g_tensor_split[id] /= total_vram;
@ -4278,6 +4275,7 @@ inline void ggml_cuda_op_mul_mat_vec(
#ifdef GGML_CUDA_FORCE_DMMV
const bool use_mul_mat_vec_q = false;
(void) g_compute_capabilities[0];
#else
int id;
CUDA_CHECK(cudaGetDevice(&id));
@ -5021,12 +5019,14 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
} else {
#ifdef GGML_CUDA_CUBLAS
const bool use_mul_mat_q = false;
#else
const bool use_mul_mat_q = ggml_is_quantized(src0->type);
#endif // GGML_CUDA_CUBLAS
if (use_mul_mat_q) {
int min_compute_capability = INT_MAX;
for (int id = 0; id < g_device_count; ++id) {
if (min_compute_capability > g_compute_capabilities[id]) {
min_compute_capability = g_compute_capabilities[id];
}
}
if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
} else {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
@ -5320,6 +5320,10 @@ void ggml_cuda_set_main_device(int main_device) {
}
}
void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
g_mul_mat_q = mul_mat_q;
}
void ggml_cuda_set_scratch_size(size_t scratch_size) {
g_scratch_size = scratch_size;
}