llama : add basic support for offloading moe with CUDA
parent 2cbcba829f
commit 06dfde3e94

3 changed files with 61 additions and 19 deletions

ggml-cuda.cu | 33 changed lines
@@ -8242,15 +8242,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     // TODO: mmq/mmv support
 #endif
 
     const struct ggml_tensor * ids = src0;
-    const int32_t id = dst->op_params[0];
-    const int32_t n_as = dst->op_params[1];
-    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
+    const int32_t id = ((int32_t *) dst->op_params)[0];
+    const int32_t n_as = ((int32_t *) dst->op_params)[1];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+    if (ids->backend == GGML_BACKEND_GPU) {
+        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    } else {
+        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+    }
 
     const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
     const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
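The hunk above removes the hard requirement that the expert-selection ids live on the GPU (the deleted GGML_ASSERT) and instead branches on ids->backend, so a partially offloaded MoE graph can keep the ids tensor in host memory. A minimal standalone sketch of that copy pattern, assuming nothing from llama.cpp (fetch_ids_to_host and this CUDA_CHECK are local, hypothetical helpers):

// Standalone sketch; build with: nvcc ids_copy.cu -o ids_copy
#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

// local stand-in for llama.cpp's CUDA_CHECK macro
#define CUDA_CHECK(expr)                                               \
    do {                                                               \
        cudaError_t err_ = (expr);                                     \
        if (err_ != cudaSuccess) {                                     \
            fprintf(stderr, "CUDA error: %s at %s:%d\n",               \
                    cudaGetErrorString(err_), __FILE__, __LINE__);     \
            exit(1);                                                   \
        }                                                              \
    } while (0)

// hypothetical helper: gather `nbytes` of expert-selection ids into `out`,
// whether the source buffer lives on the device or on the host.
// `on_device` plays the role of ids->backend == GGML_BACKEND_GPU.
static void fetch_ids_to_host(std::vector<char> & out, const void * src,
                              size_t nbytes, bool on_device, cudaStream_t stream) {
    out.resize(nbytes);
    if (on_device) {
        // device -> host copy; the host reads the ids right away,
        // so the copy must be complete before we return
        CUDA_CHECK(cudaMemcpyAsync(out.data(), src, nbytes,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaStreamSynchronize(stream));
    } else {
        // ids already live in host memory; a plain memcpy is enough
        memcpy(out.data(), src, nbytes);
    }
}

int main() {
    const int32_t ids[4] = {1, 0, 3, 2}; // one expert index per token

    // device-resident case
    int32_t * ids_dev = nullptr;
    CUDA_CHECK(cudaMalloc(&ids_dev, sizeof(ids)));
    CUDA_CHECK(cudaMemcpy(ids_dev, ids, sizeof(ids), cudaMemcpyHostToDevice));

    std::vector<char> ids_host;
    fetch_ids_to_host(ids_host, ids_dev, sizeof(ids), /*on_device=*/true, 0);
    printf("expert for token 0 (from device): %d\n",
           ((const int32_t *) ids_host.data())[0]);

    // host-resident case
    fetch_ids_to_host(ids_host, ids, sizeof(ids), /*on_device=*/false, 0);
    printf("expert for token 0 (from host):   %d\n",
           ((const int32_t *) ids_host.data())[0]);

    CUDA_CHECK(cudaFree(ids_dev));
    return 0;
}

The synchronize after the async copy matters: the caller reads ids_host immediately, so the device-to-host transfer must have completed first.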
@@ -8264,7 +8270,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     src1_row.ne[1] = 1;
     dst_row.ne[1] = 1;
 
-    src1_row.extra = &src1_row_extra;
+    if (src1->backend == GGML_BACKEND_GPU) {
+        src1_row.extra = &src1_row_extra;
+    }
     dst_row.extra = &dst_row_extra;
 
     for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -8278,7 +8286,12 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
 
         const struct ggml_tensor * src0_row = dst->src[row_id + 2];
 
-        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        if (src1->backend == GGML_BACKEND_GPU) {
+            src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        } else {
+            src1_row.data = (char *) src1->data + i01*src1->nb[1];
+        }
+
         dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
 
         ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
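These two hunks apply the same host/device split to src1: the one-row view src1_row is given a GPU extra only when src1 is actually offloaded; otherwise its plain data pointer is advanced to row i01 of the host buffer. A small sketch of that per-row view selection, with a hypothetical row_view struct in place of the real ggml_tensor/ggml_tensor_extra_gpu types:

// Standalone sketch; compiles as plain C++
#include <cstddef>
#include <cstdint>
#include <cstdio>

// only one of the two pointers is meaningful, depending on where
// the tensor lives (like ggml_tensor::data vs. the GPU extra)
struct row_view {
    void * data        = nullptr; // host pointer
    void * data_device = nullptr; // device pointer
};

// point a one-row view at row i01 of a larger tensor; nb1 is the byte
// stride between rows, like ggml's tensor->nb[1]
static row_view make_row_view(void * host_base, void * device_base,
                              size_t nb1, int64_t i01, bool on_gpu) {
    row_view v;
    if (on_gpu) {
        v.data_device = (char *) device_base + i01*nb1;
    } else {
        v.data = (char *) host_base + i01*nb1;
    }
    return v;
}

int main() {
    float rows[3][4] = {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}};

    // host-resident tensor: the view's host pointer lands on row 2
    row_view v = make_row_view(rows, nullptr, sizeof(rows[0]), 2, /*on_gpu=*/false);
    printf("row 2 starts with %.0f\n", *(const float *) v.data);
    return 0;
}

Advancing by the byte stride nb[1] rather than by an element count keeps the arithmetic correct for padded or quantized rows.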
@@ -8694,7 +8707,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             func = ggml_cuda_repeat;
             break;
         case GGML_OP_GET_ROWS:
-            func = ggml_cuda_get_rows;
+            if (ggml_is_contiguous(tensor->src[1])) {
+                func = ggml_cuda_get_rows;
+            }
             break;
         case GGML_OP_DUP:
            func = ggml_cuda_dup;
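The final hunk stops routing GGML_OP_GET_ROWS to the CUDA kernel when the row-index tensor src[1] is not contiguous; with func left unset, the op can presumably be handled by the fallback path instead. A toy sketch of that kind of contiguity gate, using a hypothetical tensor struct rather than ggml's:

// Standalone sketch; compiles as plain C++
#include <cstddef>
#include <cstdio>

// hypothetical 2-D tensor descriptor standing in for ggml_tensor
struct tensor {
    size_t ne0;       // elements along dim 0
    size_t nb0;       // byte stride between elements along dim 0
    size_t nb1;       // byte stride between rows
    size_t type_size; // bytes per element
};

// contiguity in the spirit of ggml_is_contiguous: no gaps between
// elements within a row and no padding between rows
static bool is_contiguous(const tensor & t) {
    return t.nb0 == t.type_size && t.nb1 == t.ne0*t.nb0;
}

static void cuda_get_rows_stub(const tensor &) { /* GPU kernel would run here */ }

int main() {
    const tensor packed  = {8, 4, 32, 4}; // densely packed int32 indices
    const tensor strided = {8, 8, 64, 4}; // every other element: has gaps

    // mirror of the dispatch pattern: only take the GPU path when the
    // index tensor is contiguous, otherwise leave func unset
    void (*func)(const tensor &) = nullptr;
    if (is_contiguous(packed)) {
        func = cuda_get_rows_stub;
    }
    printf("packed ids  -> %s\n", func ? "GPU get_rows" : "fallback");

    func = nullptr;
    if (is_contiguous(strided)) {
        func = cuda_get_rows_stub;
    }
    printf("strided ids -> %s\n", func ? "GPU get_rows" : "fallback");
    return 0;
}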