Removed obsolete code, fixed multi GPU
This commit is contained in:
parent
95120f1365
commit
8e3057b24b
1 changed files with 37 additions and 175 deletions
212
ggml-cuda.cu
212
ggml-cuda.cu
|
@ -743,7 +743,7 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y
|
|||
}
|
||||
}
|
||||
|
||||
static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y) {
|
||||
static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
|
||||
const half * x = (half *) vx;
|
||||
// const int col_x = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
// const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
|
@ -752,7 +752,6 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
|
|||
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
|
||||
|
||||
const int nrows_y = ncols_x;
|
||||
const int ncols_dst = ncols_y;
|
||||
const int nrows_dst = nrows_x;
|
||||
const int row_dst = row_x;
|
||||
|
||||
|
@ -775,11 +774,11 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
|
|||
// y is not transposed but permuted
|
||||
const int iy = channel*nrows_y + row_y;
|
||||
|
||||
// dst is not transposed and not permuted
|
||||
|
||||
tmp += xi * y[iy];
|
||||
}
|
||||
const int idst = channel*ncols_dst*nrows_dst + row_dst;
|
||||
|
||||
// dst is not transposed and not permuted
|
||||
const int idst = channel*nrows_dst + row_dst;
|
||||
|
||||
// sum up partial sums and write back result
|
||||
__syncthreads();
|
||||
|
@ -1143,10 +1142,10 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y, cudaStream_t stream) {
|
||||
static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
|
||||
const dim3 block_nums(1, nrows_x, nchannels_x);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, ncols_y);
|
||||
mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_vec_nc_f16_f32_cuda(
|
||||
|
@ -2024,14 +2023,6 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
|
|||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
|
||||
// if (strcmp(dst->name, "KQ") == 0) {
|
||||
// fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n",
|
||||
// src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
// src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
||||
// dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// TODO: find the optimal values for these
|
||||
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
|
@ -2043,169 +2034,60 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
|
|||
return false;
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_p021_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
GGML_ASSERT(!ggml_is_contiguous(src0) && !ggml_is_contiguous(src1));
|
||||
void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
|
||||
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]);
|
||||
GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]);
|
||||
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
|
||||
GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
// GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
GGML_ASSERT(ne00 % 8 == 0);
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
const size_t src0_size = ggml_nbytes(src0);
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
GGML_ASSERT(ne10 % 4 == 0);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne12 == ne02);
|
||||
const size_t src1_size = ggml_nbytes(src1);
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
const size_t dst_size = ggml_nbytes(dst);
|
||||
|
||||
const int64_t nb1 = dst->nb[1];
|
||||
const int64_t nb2 = dst->nb[2];
|
||||
const int64_t nb3 = dst->nb[3];
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
||||
|
||||
void * src0_ddq;
|
||||
float * src1_ddf;
|
||||
float * dst_ddf;
|
||||
if (src0->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size));
|
||||
CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main));
|
||||
} else {
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
src0_ddq = src0_extra->data_device[g_main_device];
|
||||
// CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0)));
|
||||
}
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size));
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main));
|
||||
} else {
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
// CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1)));
|
||||
}
|
||||
if (dst->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size));
|
||||
} else {
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
}
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
float * src1_ddf_i = src1_ddf + i11 * ne10*ne12;
|
||||
float * dst_ddf_i = dst_ddf + i11 * ne0;
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
|
||||
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf_i, dst_ddf_i, ne00, ne01, ne02, ne11, cudaStream_main);
|
||||
}
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
if (dst->backend == GGML_BACKEND_CPU) {
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
const int64_t i = i3*ne2*ne1 + i2*ne1 + i1;
|
||||
float * dst_ddf_i = dst_ddf + i*ne0;
|
||||
float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
||||
CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
if (src0->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaFree(src0_ddq));
|
||||
}
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaFree(src1_ddf));
|
||||
}
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaFree(dst_ddf));
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec_nc_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(!ggml_is_permuted(src0));
|
||||
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
// GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
const size_t src0_size = ggml_nbytes(src0);
|
||||
|
||||
const int64_t nb01 = src0->nb[1];
|
||||
const int64_t nb02 = src0->nb[2];
|
||||
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne12 == ne02);
|
||||
const size_t src1_size = ggml_nbytes(src1);
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
const size_t dst_size = ggml_nbytes(dst);
|
||||
|
||||
const int64_t nb1 = dst->nb[1];
|
||||
const int64_t nb2 = dst->nb[2];
|
||||
const int64_t nb3 = dst->nb[3];
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
||||
|
||||
void * src0_ddq;
|
||||
float * src1_ddf;
|
||||
float * dst_ddf;
|
||||
if (src0->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size));
|
||||
CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main));
|
||||
} else {
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
src0_ddq = src0_extra->data_device[g_main_device];
|
||||
// CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0)));
|
||||
}
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size));
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main));
|
||||
} else {
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
// CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1)));
|
||||
}
|
||||
if (dst->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size));
|
||||
} else {
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
}
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
const int row_stride_x = nb01 / sizeof(half);
|
||||
const int channel_stride_x = nb02 / sizeof(half);
|
||||
|
@ -2213,37 +2095,16 @@ void ggml_cuda_mul_mat_vec_nc_f16_f32(const ggml_tensor * src0, const ggml_tenso
|
|||
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
if (dst->backend == GGML_BACKEND_CPU) {
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
const int64_t i = i3*ne2*ne1 + i2*ne1 + i1;
|
||||
float * dst_ddf_i = dst_ddf + i*ne0;
|
||||
float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
||||
CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
if (src0->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaFree(src0_ddq));
|
||||
}
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaFree(src1_ddf));
|
||||
}
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
CUDA_CHECK(cudaFree(dst_ddf));
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
ggml_cuda_mul_mat_p021_f16_f32(src0, src1, dst);
|
||||
} else if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
|
||||
ggml_cuda_mul_mat_vec_nc_f16_f32(src0, src1, dst);
|
||||
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
|
||||
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
|
||||
|
||||
if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
||||
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
}else if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||
|
@ -2288,6 +2149,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
|||
const int64_t nb11 = src1->nb[1];
|
||||
const int64_t nb12 = src1->nb[2];
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
||||
|
||||
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue