This commit is contained in:
slaren 2024-04-06 15:19:47 +02:00
parent ea2b79534e
commit 1b5d78d3ee

View file

@ -1961,22 +1961,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
} }
struct mmid_row_mapping { struct mmid_row_mapping {
int64_t i1; int32_t i1;
int64_t i2; int32_t i2;
}; };
static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous, static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
int * cur_src1_row, mmid_row_mapping * row_mapping, int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0, const char * ids_dev, int64_t i02, size_t ids_nb1, size_t ids_nb0,
int64_t ids_ne1, int64_t n_ids, int64_t ne11, int64_t ne10,
int64_t ne11,
size_t nb11, size_t nb12) { size_t nb11, size_t nb12) {
int64_t iid1 = blockIdx.x; int32_t iid1 = blockIdx.x;
int64_t id = blockIdx.y; int32_t id = blockIdx.y;
if (iid1 >= ids_ne1 || id >= n_ids) {
return;
}
const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0); const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0);
@ -1994,31 +1989,27 @@ static __global__ void k_copy_src1_to_contiguous(const char * src1_original, cha
} }
__syncthreads(); __syncthreads();
const char * src1_row_original = src1_original + i11*nb11 + i12*nb12; const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
char * src1_row_contiguous = src1_contiguous + src1_row*nb11; float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
for (int i = threadIdx.x; i < nb11; i += blockDim.x) { for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
src1_row_contiguous[i] = src1_row_original[i]; src1_row_contiguous[i] = src1_row_original[i];
} }
} }
static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous, static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
const mmid_row_mapping * row_mapping, const mmid_row_mapping * __restrict__ row_mapping,
int64_t n_rows, int64_t ne0,
int64_t nb1, int64_t nb2) { size_t nb1, size_t nb2) {
int64_t i = blockIdx.x; int32_t i = blockIdx.x;
if (i >= n_rows) { const int32_t i1 = row_mapping[i].i1;
return; const int32_t i2 = row_mapping[i].i2;
}
const int64_t i1 = row_mapping[i].i1; const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
const int64_t i2 = row_mapping[i].i2; float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
const char * dst_row_contiguous = dst_contiguous + i*nb1; for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
char * dst_row_original = dst_original + i1*nb1 + i2*nb2;
for (int j = threadIdx.x; j < nb1; j += blockDim.x) {
dst_row_original[j] = dst_row_contiguous[j]; dst_row_original[j] = dst_row_contiguous[j];
} }
} }
@ -2129,14 +2120,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
{ {
dim3 block_dims(std::min((uint)nb11, 1024u)); dim3 block_dims(std::min((uint)ne10, 512u));
dim3 grid_dims(ids->ne[1], n_ids); dim3 grid_dims(ids->ne[1], n_ids);
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>( k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
src1_original, src1_contiguous.get(), src1_original, src1_contiguous.get(),
dev_cur_src1_row.get(), dev_row_mapping.get(), dev_cur_src1_row.get(), dev_row_mapping.get(),
ids_dev, i02, ids->nb[1], ids->nb[0], ids_dev, i02, ids->nb[1], ids->nb[0],
ids->ne[1], n_ids, ne11, ne10,
ne11,
nb11, nb12); nb11, nb12);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
} }
@ -2161,12 +2151,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
#ifndef MMID_MEMCPY #ifndef MMID_MEMCPY
{ {
dim3 block_dims(std::min((uint)nb1, 1024u)); dim3 block_dims(std::min((uint)ne0, 512u));
dim3 grid_dims(num_src1_rows); dim3 grid_dims(num_src1_rows);
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>( k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
dst_original, dst_contiguous.get(), dst_original, dst_contiguous.get(),
dev_row_mapping.get(), dev_row_mapping.get(),
num_src1_rows, nb1, nb2); ne0,
nb1, nb2);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
} }
#endif #endif