flatten rows for ggml_cuda_op

This commit is contained in:
parent 3b6a2ee414
commit 95120f1365

1 changed file with 46 additions and 37 deletions

ggml-cuda.cu  (83 lines changed: +46 −37)

@@ -897,7 +897,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int n_past) {
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -907,7 +907,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
     const int i = row*ncols + col;
     // dst[i] = col > n_past + row ? -INFINITY : x[i];
-    dst[i] = x[i] - (col > n_past + row) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
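
Not part of the commit, just an annotation: a minimal CPU-side sketch of what the updated masking line computes. Because rows from every channel now share a single kernel launch, row % rows_per_channel (ne01 on the host side) recovers the row index inside the current channel before it is compared against n_past. The helper name and its standalone form are illustrative only.

#include <limits.h>

// Hypothetical reference for one element; mirrors the kernel's arithmetic trick:
// subtracting INT_MAX when the predicate is true stands in for writing -INFINITY.
static float diag_mask_ref(float x, int col, int row, int rows_per_channel, int n_past) {
    const int row_in_channel = row % rows_per_channel; // undo the row flattening
    return x - (float)(col > n_past + row_in_channel) * (float)INT_MAX;
}
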
@@ -1192,11 +1192,11 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
 }
 
-static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int n_past, cudaStream_t stream) {
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, nrows_x, 1);
-    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, n_past);
+    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
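
The launch geometry itself is unchanged: blocks of CUDA_DIAG_MASK_INF_BLOCK_SIZE threads along the columns and one grid row per output row. What changes is the meaning of nrows_x, which with flattened rows can span every channel at once. A hedged sketch with made-up sizes (block_size stands in for the real constant):

// Illustrative only: the 1024 x 32 x 8 shape and the block_size value are invented for the example.
static void example_diag_mask_launch_dims(void) {
    const int block_size       = 32;                 // stand-in for CUDA_DIAG_MASK_INF_BLOCK_SIZE
    const int ncols_x          = 1024;               // ne00
    const int rows_per_channel = 32;                 // ne01
    const int nrows_x          = 32 * 8;             // ne01 * ne02 * ne03: all rows in one launch
    const int block_num_x      = (ncols_x + block_size - 1) / block_size;
    const dim3 block_dims(block_size, 1, 1);         // threads over columns
    const dim3 block_nums(block_num_x, nrows_x, 1);  // grid.y indexes flattened rows
    (void) block_dims; (void) block_nums; (void) rows_per_channel;
}
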
@@ -1659,12 +1659,13 @@ inline void ggml_cuda_op_diag_mask_inf(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
 
     const int n_past = ((int32_t *) src1->data)[0];
 
     // compute
-    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_past, cudaStream_main);
+    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
     CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
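
To make the two row counts explicit (an illustration with invented numbers, not code from the commit): i01_diff counts the rows actually launched, which under flattening is all rows of all channels, while ne01 is the height of a single channel and is forwarded so the kernel can place each diagonal correctly.

// Hypothetical shape: ne00 = 1024, ne01 = 32, ne02 = 8, ne03 = 1.
// With flatten_rows, ggml_cuda_op calls this op once with
//   i01_low = 0, i01_high = 32 * 8 = 256  ->  i01_diff = 256 rows launched,
//   ne01    = 32                          ->  rows_per_channel in the kernel,
// so launched row 70 maps to row 70 % 32 = 6 of channel 70 / 32 = 2 (integer division).
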
@@ -1723,7 +1724,7 @@ inline void ggml_cuda_op_scale(
 }
 
 static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                         ggml_cuda_op_t op, bool src0_needs_f32) {
+                         ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
@@ -1746,10 +1747,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
 
     // strides for iteration over dims 3 and 2
-    const int64_t src0_stride = ne00 * ne01;
-    const int64_t src1_stride = ne10 * ne11;
-    const int64_t dst_stride = ne0 * ne1;
-    const int64_t num_iters = ne02 * ne03;
+    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
+    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t src0_stride = ne00 * ne01 * stride_mod;
+    const int64_t src1_stride = ne10 * ne11 * stride_mod;
+    const int64_t dst_stride = ne0 * ne1 * stride_mod;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
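
A quick illustration (again not from the commit) of what flattening does to the iteration: the dims-3/2 loop collapses to a single pass and the per-pass strides grow by the same factor, so the total number of elements covered is unchanged. The shape below is made up.

#include <stdint.h>
#include <assert.h>

static void example_flatten_strides(int flatten_rows) {
    const int64_t ne00 = 4096, ne01 = 32, ne02 = 8, ne03 = 1; // hypothetical src0 shape

    const int64_t num_iters   = flatten_rows ? 1 : ne02 * ne03;  // 1 pass vs 8 passes
    const int64_t stride_mod  = flatten_rows ? ne02 * ne03 : 1;  // 8 vs 1
    const int64_t src0_stride = ne00 * ne01 * stride_mod;        // elements consumed per pass

    // Either way the same number of src0 elements is processed in total.
    assert(num_iters * src0_stride == ne00 * ne01 * ne02 * ne03);
    (void) src0_stride;
}
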
@@ -1763,6 +1765,8 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
 
     const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
+    const bool src1_stays_on_host = use_src1 && (
+        dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
 
@@ -1772,13 +1776,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     char * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
     float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
     float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
     float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
     // asq = actual size quantized, asf = actual size float
     size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
     size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
     for (int id = 0; id < g_device_count; ++id) {
         if (!split && id != g_main_device) {
@@ -1824,7 +1828,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
         }
 
-        if (use_src1) {
+        if (use_src1 && !src1_stays_on_host) {
             if (src1_on_device && src1_is_contiguous) {
                 src1_ddf[id] = (float *) src1_extra->data_device[id];
             } else {
@@ -1838,26 +1842,29 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
         }
 
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
+        const int64_t i03_max = flatten_rows ? 1 : ne03;
+        const int64_t i02_max = flatten_rows ? 1 : ne02;
+        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+        for (int64_t i03 = 0; i03 < i03_max; i03++) {
             const int64_t i13 = i03 % ne13;
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i02 = 0; i02 < i02_max; i02++) {
                 const int64_t i12 = i02 % ne12;
 
                 const int64_t i0 = i03*ne02 + i02;
-                const int64_t i0_offset_low = row_low/ne01;
-                const int64_t i0_offset_high = row_high/ne01;
+                const int64_t i0_offset_low = row_low/rows_per_iter;
+                const int64_t i0_offset_high = row_high/rows_per_iter;
 
                 int64_t i01_low = 0;
-                int64_t i01_high = ne01;
+                int64_t i01_high = rows_per_iter;
                 if (split) {
                     if (i0 < i0_offset_low || i0 > i0_offset_high) {
                         continue;
                     }
                     if (i0 == i0_offset_low) {
-                        i01_low = row_low % ne01;
+                        i01_low = row_low % rows_per_iter;
                     }
                     if (i0 == i0_offset_high) {
-                        i01_high = row_high % ne01;
+                        i01_high = row_high % rows_per_iter;
                     }
                 }
 
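
Summarizing the new control flow (a simplified sketch, not the literal code): when flatten_rows is set, the two outer loops degenerate to a single iteration and the whole tensor is treated as one block of nrows0 rows, so a per-row op runs once instead of once per (i03, i02) slice.

#include <stdint.h>

// Simplified model of the iteration in ggml_cuda_op; names mirror the diff, logic is abridged.
static void example_iteration(int flatten_rows, int64_t ne01, int64_t ne02, int64_t ne03,
                              void (*run_op)(int64_t rows)) {
    const int64_t nrows0        = ne01 * ne02 * ne03;
    const int64_t i03_max       = flatten_rows ? 1 : ne03;
    const int64_t i02_max       = flatten_rows ? 1 : ne02;
    const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;

    for (int64_t i03 = 0; i03 < i03_max; i03++) {
        for (int64_t i02 = 0; i02 < i02_max; i02++) {
            run_op(rows_per_iter); // one op invocation covering rows_per_iter rows
        }
    }
}
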
@@ -1866,7 +1873,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
                // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
                GGML_ASSERT(i01_low == 0 || g_device_count > 1);
-                GGML_ASSERT(i01_high == ne01 || g_device_count > 1);
+                GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
 
                const int64_t i01_diff = i01_high - i01_low;
                if (i01_diff == 0) {
@@ -1887,11 +1894,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                // for split tensors the data pointer needs to be rounded down
                // to the bin edge for i03, i02 bins beyond the first
                if (i0 - i0_offset_low > 0) {
+                    GGML_ASSERT(!flatten_rows);
                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
                    src0_ddf_i -= (row_low % ne01)*ne00;
-                }
-                if (i0 - i0_offset_low > 0) {
                    dst_ddf_i -= (row_low % ne0)*ne1;
                }
 
                // the main device memory buffer can be on VRAM scratch, with space for all partial results
@@ -1901,11 +1907,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                }
 
                // copy src0, src1 to device if necessary
-                if (use_src1) {
+                if (use_src1 && !src1_stays_on_host) {
                    if (src1->backend == GGML_BACKEND_CPU) {
-                        CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
+                        GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
+                        int64_t nrows1 = flatten_rows ? nrows0 : ne11;
+                        CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
                    } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                        if (id != g_main_device) {
+                            GGML_ASSERT(!flatten_rows);
                            float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                            src1_ddf_i_source += i11*src1_stride;
                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
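
One more annotation (not from the commit): with flattened rows the host-to-device copy of src1 must cover all of src1 in one transfer, which is why the new code asserts that src1 has exactly as many rows as src0 and then copies nrows1 = nrows0 rows instead of a single ne11-row slice. A hedged helper that captures just that decision:

#include <stdint.h>
#include <assert.h>

// Hypothetical helper; returns how many src1 rows one iteration has to copy.
static int64_t example_src1_rows_to_copy(int flatten_rows, int64_t nrows0, int64_t nrows1_total, int64_t ne11) {
    if (flatten_rows) {
        assert(nrows0 == nrows1_total); // mirrors GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1))
        return nrows0;                  // every row of src1 in a single transfer
    }
    return ne11;                        // otherwise one (i03, i02) slice of src1 per iteration
}
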
@@ -1991,22 +2000,22 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
 }
 
 void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
 }
 
 void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
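
For readability of the call sites above, a restatement of the new entry point (not additional code in the commit): the trailing boolean selects row flattening, so element-wise ops pass true while a few ops keep it off for now (see the TODO and FIXME in the diff).

//   static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
//                            ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows);
// e.g. ggml_cuda_add passes (..., true, true): promote src0 to f32 and process all rows at once;
// ggml_cuda_mul keeps flatten_rows == false, as its TODO notes.
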
@@ -2236,12 +2245,12 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     } else if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_nc_f16_f32(src0, src1, dst);
     }else if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
+        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
         } else {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
         }
     } else {
         GGML_ASSERT(false);
@@ -2250,7 +2259,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
 
 void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
 }
 
 void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2304,17 +2313,17 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
 
 void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
 }
 
 void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
 }
 
 void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
 }
 
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {