make MUL_MAT_SRC1_COL_STRIDE conditional on runtime mmq
This commit is contained in:
parent
12fb1c58ec
commit
dd71a35cc8
1 changed files with 7 additions and 3 deletions
|
@ -467,6 +467,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
|
||||||
#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
|
#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
|
||||||
#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
|
#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
|
||||||
|
|
||||||
|
#define MUL_MAT_SRC1_COL_STRIDE_MMQ 128
|
||||||
#define MUL_MAT_SRC1_COL_STRIDE 4096
|
#define MUL_MAT_SRC1_COL_STRIDE 4096
|
||||||
|
|
||||||
#define MAX_STREAMS 8
|
#define MAX_STREAMS 8
|
||||||
|
@ -7158,7 +7159,10 @@ static void ggml_cuda_op_mul_mat(
|
||||||
CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
|
CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
|
const int64_t src1_col_stride = !split || used_devices == 1 ? ne11 :
|
||||||
|
convert_src1_to_q8_1 ? MUL_MAT_SRC1_COL_STRIDE_MMQ :
|
||||||
|
MUL_MAT_SRC1_COL_STRIDE;
|
||||||
|
|
||||||
for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
|
for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
|
||||||
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
|
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
|
||||||
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
|
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
|
||||||
|
@ -7296,7 +7300,7 @@ static void ggml_cuda_op_mul_mat(
|
||||||
|
|
||||||
// main device waits for all other devices to be finished
|
// main device waits for all other devices to be finished
|
||||||
if (split && g_device_count > 1) {
|
if (split && g_device_count > 1) {
|
||||||
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
|
int64_t is_max = (ne11 + src1_col_stride - 1) / src1_col_stride;
|
||||||
is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
|
is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
|
||||||
|
|
||||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue