add cuda_pool_alloc, refactor most pool allocations
ggml-ci
This commit is contained in:
parent
545f23d07b
commit
110b5055da
1 changed files with 83 additions and 109 deletions
192
ggml-cuda.cu
192
ggml-cuda.cu
|
@ -220,6 +220,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
|
case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||||
default: return "unknown error";
|
default: return "unknown error";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
#endif // CUDART_VERSION >= 12000
|
#endif // CUDART_VERSION >= 12000
|
||||||
|
|
||||||
static const char * cu_get_error_str(CUresult err) {
|
static const char * cu_get_error_str(CUresult err) {
|
||||||
|
@ -6739,6 +6740,39 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||||
#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
|
#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct cuda_pool_alloc {
|
||||||
|
T * ptr = nullptr;
|
||||||
|
size_t act_size = 0;
|
||||||
|
|
||||||
|
// size is in number of elements
|
||||||
|
T * alloc(size_t size) {
|
||||||
|
GGML_ASSERT(ptr == nullptr);
|
||||||
|
ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->act_size);
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
cuda_pool_alloc(size_t size) {
|
||||||
|
alloc(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
~cuda_pool_alloc() {
|
||||||
|
if (ptr != nullptr) {
|
||||||
|
ggml_cuda_pool_free(ptr, act_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
T * get() {
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
cuda_pool_alloc() = default;
|
||||||
|
cuda_pool_alloc(const cuda_pool_alloc &) = delete;
|
||||||
|
cuda_pool_alloc(cuda_pool_alloc &&) = delete;
|
||||||
|
cuda_pool_alloc& operator=(const cuda_pool_alloc &) = delete;
|
||||||
|
cuda_pool_alloc& operator=(cuda_pool_alloc &&) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
static bool g_cublas_loaded = false;
|
static bool g_cublas_loaded = false;
|
||||||
|
|
||||||
bool ggml_cublas_loaded(void) {
|
bool ggml_cublas_loaded(void) {
|
||||||
|
@ -7432,8 +7466,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
|
|
||||||
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
||||||
#ifdef GGML_CUDA_F16
|
#ifdef GGML_CUDA_F16
|
||||||
size_t ash;
|
cuda_pool_alloc<half> src1_dfloat_a;
|
||||||
dfloat * src1_dfloat = nullptr; // dfloat == half
|
half * src1_dfloat = nullptr; // dfloat == half
|
||||||
|
|
||||||
bool src1_convert_f16 =
|
bool src1_convert_f16 =
|
||||||
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
|
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
|
||||||
|
@ -7441,7 +7475,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
|
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
|
||||||
|
|
||||||
if (src1_convert_f16) {
|
if (src1_convert_f16) {
|
||||||
src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
|
src1_dfloat = src1_dfloat_a.alloc(ne00);
|
||||||
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
|
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
|
||||||
ne00, 1, sizeof(float), 0, 0,
|
ne00, 1, sizeof(float), 0, 0,
|
||||||
ne00, 1, sizeof(half), 0, 0, stream);
|
ne00, 1, sizeof(half), 0, 0, stream);
|
||||||
|
@ -7489,12 +7523,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GGML_CUDA_F16
|
|
||||||
if (src1_convert_f16) {
|
|
||||||
ggml_cuda_pool_free(src1_dfloat, ash);
|
|
||||||
}
|
|
||||||
#endif // GGML_CUDA_F16
|
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
(void) src1_ddq_i;
|
(void) src1_ddq_i;
|
||||||
|
@ -7529,29 +7557,26 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
|
|
||||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||||
half * src0_as_f16 = nullptr;
|
cuda_pool_alloc<half> src0_as_f16;
|
||||||
size_t src0_as = 0;
|
|
||||||
if (src0->type != GGML_TYPE_F16) {
|
if (src0->type != GGML_TYPE_F16) {
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
size_t ne = row_diff*ne00;
|
size_t ne = row_diff*ne00;
|
||||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
|
src0_as_f16.alloc(ne);
|
||||||
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
|
to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
|
||||||
}
|
}
|
||||||
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
|
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
|
||||||
|
|
||||||
half * src1_as_f16 = nullptr;
|
cuda_pool_alloc<half> src1_as_f16;
|
||||||
size_t src1_as = 0;
|
|
||||||
if (src1->type != GGML_TYPE_F16) {
|
if (src1->type != GGML_TYPE_F16) {
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
size_t ne = src1_ncols*ne10;
|
size_t ne = src1_ncols*ne10;
|
||||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
|
src1_as_f16.alloc(ne);
|
||||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
|
||||||
}
|
}
|
||||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
|
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
|
||||||
size_t dst_as = 0;
|
cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
|
||||||
|
|
||||||
const half alpha_f16 = 1.0f;
|
const half alpha_f16 = 1.0f;
|
||||||
const half beta_f16 = 0.0f;
|
const half beta_f16 = 0.0f;
|
||||||
|
@ -7560,36 +7585,25 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
row_diff, src1_ncols, ne10,
|
row_diff, src1_ncols, ne10,
|
||||||
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
||||||
src1_ptr, CUDA_R_16F, ne10,
|
src1_ptr, CUDA_R_16F, ne10,
|
||||||
&beta_f16, dst_f16, CUDA_R_16F, ldc,
|
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
|
||||||
CUBLAS_COMPUTE_16F,
|
CUBLAS_COMPUTE_16F,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||||
|
|
||||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
|
||||||
|
|
||||||
if (src1_as != 0) {
|
|
||||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (src0_as != 0) {
|
|
||||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
float * src0_ddq_as_f32 = nullptr;
|
cuda_pool_alloc<float> src0_ddq_as_f32;
|
||||||
size_t src0_as = 0;
|
|
||||||
|
|
||||||
if (src0->type != GGML_TYPE_F32) {
|
if (src0->type != GGML_TYPE_F32) {
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
|
src0_ddq_as_f32.alloc(row_diff*ne00);
|
||||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
|
||||||
}
|
}
|
||||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
|
||||||
|
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
|
@ -7601,10 +7615,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
&alpha, src0_ddf_i, ne00,
|
&alpha, src0_ddf_i, ne00,
|
||||||
src1_ddf_i, ne10,
|
src1_ddf_i, ne10,
|
||||||
&beta, dst_dd_i, ldc));
|
&beta, dst_dd_i, ldc));
|
||||||
|
|
||||||
if (src0_as != 0) {
|
|
||||||
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(void) dst;
|
(void) dst;
|
||||||
|
@ -7896,18 +7906,17 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
float * src1_ddf = nullptr;
|
float * src1_ddf = nullptr;
|
||||||
float * dst_ddf = nullptr;
|
float * dst_ddf = nullptr;
|
||||||
|
|
||||||
// as = actual size
|
cuda_pool_alloc<float> src0_f;
|
||||||
size_t src0_asf = 0;
|
cuda_pool_alloc<float> src1_f;
|
||||||
size_t src1_asf = 0;
|
cuda_pool_alloc<float> dst_f;
|
||||||
size_t dst_asf = 0;
|
|
||||||
|
|
||||||
ggml_cuda_set_device(g_main_device);
|
ggml_cuda_set_device(g_main_device);
|
||||||
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||||
|
|
||||||
if (src0_on_device) {
|
if (src0_on_device) {
|
||||||
src0_ddf = (float *) src0_extra->data_device[g_main_device];
|
src0_ddf = (float *) src0_extra->data_device[g_main_device];
|
||||||
} else {
|
} else {
|
||||||
src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
|
src0_ddf = src0_f.alloc(ggml_nelements(src0));
|
||||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
|
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7915,14 +7924,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
if (src1_on_device) {
|
if (src1_on_device) {
|
||||||
src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||||
} else {
|
} else {
|
||||||
src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
|
src1_ddf = src1_f.alloc(ggml_nelements(src1));
|
||||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
|
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (dst_on_device) {
|
if (dst_on_device) {
|
||||||
dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||||
} else {
|
} else {
|
||||||
dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
|
dst_ddf = dst_f.alloc(ggml_nelements(dst));
|
||||||
}
|
}
|
||||||
|
|
||||||
// do the computation
|
// do the computation
|
||||||
|
@ -7934,16 +7943,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
|
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dst_asf > 0) {
|
|
||||||
ggml_cuda_pool_free(dst_ddf, dst_asf);
|
|
||||||
}
|
|
||||||
if (src1_asf > 0) {
|
|
||||||
ggml_cuda_pool_free(src1_ddf, src1_asf);
|
|
||||||
}
|
|
||||||
if (src0_asf > 0) {
|
|
||||||
ggml_cuda_pool_free(src0_ddf, src0_asf);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dst->backend == GGML_BACKEND_CPU) {
|
if (dst->backend == GGML_BACKEND_CPU) {
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
@ -8516,14 +8515,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
|
|
||||||
size_t src1_as = 0;
|
cuda_pool_alloc<half> src1_as_f16(ne1);
|
||||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
|
||||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
|
||||||
|
|
||||||
size_t dst_as = 0;
|
cuda_pool_alloc<half> dst_f16;
|
||||||
|
char * dst_t;
|
||||||
half * dst_f16 = nullptr;
|
|
||||||
char * dst_t = nullptr;
|
|
||||||
|
|
||||||
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||||
cudaDataType_t cu_data_type = CUDA_R_16F;
|
cudaDataType_t cu_data_type = CUDA_R_16F;
|
||||||
|
@ -8542,8 +8538,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
const void * beta = &beta_f16;
|
const void * beta = &beta_f16;
|
||||||
|
|
||||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||||
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
dst_t = (char *) dst_f16.alloc(ne);
|
||||||
dst_t = (char *) dst_f16;
|
|
||||||
|
|
||||||
nbd2 /= sizeof(float) / sizeof(half);
|
nbd2 /= sizeof(float) / sizeof(half);
|
||||||
nbd3 /= sizeof(float) / sizeof(half);
|
nbd3 /= sizeof(float) / sizeof(half);
|
||||||
|
@ -8590,9 +8585,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
||||||
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
(const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
||||||
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
|
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
|
||||||
ne12*ne13,
|
ne12*ne13,
|
||||||
cu_compute_type,
|
cu_compute_type,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
@ -8600,19 +8595,13 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
// use cublasGemmBatchedEx
|
// use cublasGemmBatchedEx
|
||||||
const int ne23 = ne12*ne13;
|
const int ne23 = ne12*ne13;
|
||||||
|
|
||||||
const void ** ptrs_src = nullptr;
|
cuda_pool_alloc<const void *> ptrs_src(2*ne23);
|
||||||
void ** ptrs_dst = nullptr;
|
cuda_pool_alloc< void *> ptrs_dst(1*ne23);
|
||||||
|
|
||||||
size_t ptrs_src_s = 0;
|
|
||||||
size_t ptrs_dst_s = 0;
|
|
||||||
|
|
||||||
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
|
|
||||||
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
|
|
||||||
|
|
||||||
dim3 block_dims(ne13, ne12);
|
dim3 block_dims(ne13, ne12);
|
||||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||||
src0_as_f16, src1_as_f16, dst_t,
|
src0_as_f16, src1_as_f16.get(), dst_t,
|
||||||
ptrs_src, ptrs_dst,
|
ptrs_src.get(), ptrs_dst.get(),
|
||||||
ne12, ne13,
|
ne12, ne13,
|
||||||
ne23,
|
ne23,
|
||||||
nb02, nb03,
|
nb02, nb03,
|
||||||
|
@ -8624,30 +8613,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
||||||
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
||||||
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
|
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
|
||||||
ne23,
|
ne23,
|
||||||
cu_compute_type,
|
cu_compute_type,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
|
||||||
if (ptrs_dst_s != 0) {
|
|
||||||
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
|
|
||||||
}
|
|
||||||
if (ptrs_src_s != 0) {
|
|
||||||
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
|
||||||
|
|
||||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
@ -8974,12 +8952,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
|
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
size_t as_src1, as_dst;
|
cuda_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
|
||||||
char * src1_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(src1), &as_src1);
|
cuda_pool_alloc<char> dst_contiguous(sizeof(float)*ggml_nelements(dst));
|
||||||
char * dst_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(dst), &as_dst);
|
|
||||||
|
|
||||||
src1_row_extra.data_device[g_main_device] = src1_contiguous;
|
src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
|
||||||
dst_row_extra.data_device[g_main_device] = dst_contiguous;
|
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
|
||||||
|
|
||||||
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
|
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
|
||||||
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
|
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
|
||||||
|
@ -8999,7 +8976,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
|
|
||||||
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
||||||
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
|
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
|
||||||
nb11, src1_kind, stream));
|
nb11, src1_kind, stream));
|
||||||
num_src1_rows++;
|
num_src1_rows++;
|
||||||
}
|
}
|
||||||
|
@ -9031,14 +9008,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
|
|
||||||
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
||||||
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
|
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
|
||||||
nb1, dst_kind, stream));
|
nb1, dst_kind, stream));
|
||||||
num_src1_rows++;
|
num_src1_rows++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cuda_pool_free(dst_contiguous, as_dst);
|
|
||||||
ggml_cuda_pool_free(src1_contiguous, as_src1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dst->backend == GGML_BACKEND_CPU) {
|
if (dst->backend == GGML_BACKEND_CPU) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue