CPU/CUDA: fix (GQA) mul mat back, add CUDA support (#11380)
This commit is contained in:
parent
1af6945eb0
commit
8137b4bb2b
7 changed files with 156 additions and 61 deletions
|
@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
|
|||
|
||||
template <typename T>
|
||||
static __global__ void k_repeat_back(
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
|
||||
|
||||
const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
|
||||
const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
|
||||
const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
|
||||
const int64_t tid2 = tid23 % ne2;
|
||||
const int64_t tid3 = tid23 / ne2;
|
||||
|
||||
if (tid0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i2*ne01*ne00 + i1*ne00 + i0];
|
||||
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
|
@ -274,12 +279,14 @@ struct bin_bcast_cuda {
|
|||
|
||||
template <typename T>
|
||||
static void repeat_back_cuda(
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
|
||||
(src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
|
||||
}
|
||||
|
||||
template<class op>
|
||||
|
@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_can_repeat(dst, src0));
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
GGML_ASSERT(dst->ne[3] == 1);
|
||||
GGML_ASSERT(ne2*ne3 <= (1 << 15));
|
||||
|
||||
const size_t ts = ggml_type_size(src0->type);
|
||||
const size_t s00 = nb00 / ts;
|
||||
const size_t s01 = nb01 / ts;
|
||||
const size_t s02 = nb02 / ts;
|
||||
const size_t s03 = nb03 / ts;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ASSERT(false);
|
||||
|
|
|
@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
|
||||
return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
|
|
@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
CUBLAS_CHECK(cublasSetStream(handle, stream));
|
||||
|
||||
const int64_t lda = nb01 / sizeof(float);
|
||||
const int64_t ldc = nb1 / sizeof(float);
|
||||
|
||||
const bool src1_T = ggml_is_transposed(src1);
|
||||
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
||||
|
@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue