This commit is contained in:
slaren 2024-04-06 15:19:47 +02:00
parent ea2b79534e
commit 1b5d78d3ee

View file

@ -1961,22 +1961,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
struct mmid_row_mapping {
int64_t i1;
int64_t i2;
int32_t i1;
int32_t i2;
};
static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous,
int * cur_src1_row, mmid_row_mapping * row_mapping,
const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0,
int64_t ids_ne1, int64_t n_ids,
int64_t ne11,
static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
const char * ids_dev, int64_t i02, size_t ids_nb1, size_t ids_nb0,
int64_t ne11, int64_t ne10,
size_t nb11, size_t nb12) {
int64_t iid1 = blockIdx.x;
int64_t id = blockIdx.y;
if (iid1 >= ids_ne1 || id >= n_ids) {
return;
}
int32_t iid1 = blockIdx.x;
int32_t id = blockIdx.y;
const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0);
@ -1994,31 +1989,27 @@ static __global__ void k_copy_src1_to_contiguous(const char * src1_original, cha
}
__syncthreads();
const char * src1_row_original = src1_original + i11*nb11 + i12*nb12;
char * src1_row_contiguous = src1_contiguous + src1_row*nb11;
const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
for (int i = threadIdx.x; i < nb11; i += blockDim.x) {
for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
src1_row_contiguous[i] = src1_row_original[i];
}
}
static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous,
const mmid_row_mapping * row_mapping,
int64_t n_rows,
int64_t nb1, int64_t nb2) {
int64_t i = blockIdx.x;
static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
const mmid_row_mapping * __restrict__ row_mapping,
int64_t ne0,
size_t nb1, size_t nb2) {
int32_t i = blockIdx.x;
if (i >= n_rows) {
return;
}
const int32_t i1 = row_mapping[i].i1;
const int32_t i2 = row_mapping[i].i2;
const int64_t i1 = row_mapping[i].i1;
const int64_t i2 = row_mapping[i].i2;
const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
const char * dst_row_contiguous = dst_contiguous + i*nb1;
char * dst_row_original = dst_original + i1*nb1 + i2*nb2;
for (int j = threadIdx.x; j < nb1; j += blockDim.x) {
for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
dst_row_original[j] = dst_row_contiguous[j];
}
}
@ -2129,14 +2120,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
{
dim3 block_dims(std::min((uint)nb11, 1024u));
dim3 block_dims(std::min((uint)ne10, 512u));
dim3 grid_dims(ids->ne[1], n_ids);
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
src1_original, src1_contiguous.get(),
dev_cur_src1_row.get(), dev_row_mapping.get(),
ids_dev, i02, ids->nb[1], ids->nb[0],
ids->ne[1], n_ids,
ne11,
ne11, ne10,
nb11, nb12);
CUDA_CHECK(cudaGetLastError());
}
@ -2161,12 +2151,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
#ifndef MMID_MEMCPY
{
dim3 block_dims(std::min((uint)nb1, 1024u));
dim3 block_dims(std::min((uint)ne0, 512u));
dim3 grid_dims(num_src1_rows);
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
dst_original, dst_contiguous.get(),
dev_row_mapping.get(),
num_src1_rows, nb1, nb2);
ne0,
nb1, nb2);
CUDA_CHECK(cudaGetLastError());
}
#endif