ggml : update ggml_soft_max_ext() CUDA, SYCL
This commit is contained in:
parent
7fdca3348c
commit
d0592d495d
4 changed files with 38 additions and 74 deletions
|
@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||||
static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
||||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
@ -23,7 +23,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
float slope = 0.0f;
|
float slope = 1.0f;
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
|
@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
|
||||||
const int64_t ix = (int64_t)rowx*ncols + col;
|
const int64_t ix = (int64_t)rowx*ncols + col;
|
||||||
const int64_t iy = (int64_t)rowy*ncols + col;
|
const int64_t iy = (int64_t)rowy*ncols + col;
|
||||||
|
|
||||||
const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
|
const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
|
||||||
|
|
||||||
vals[col] = val;
|
vals[col] = val;
|
||||||
max_val = max(max_val, val);
|
max_val = max(max_val, val);
|
||||||
|
@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
|
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
|
||||||
int nth = WARP_SIZE;
|
int nth = WARP_SIZE;
|
||||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
const dim3 block_dims(nth, 1, 1);
|
const dim3 block_dims(nth, 1, 1);
|
||||||
|
@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
|
||||||
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
||||||
switch (ncols_x) {
|
switch (ncols_x) {
|
||||||
case 32:
|
case 32:
|
||||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 64:
|
case 64:
|
||||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 128:
|
case 128:
|
||||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 256:
|
case 256:
|
||||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 512:
|
case 512:
|
||||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 1024:
|
case 1024:
|
||||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 2048:
|
case 2048:
|
||||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
case 4096:
|
case 4096:
|
||||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const ggml_tensor * src1 = dst->src[1];
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
const ggml_tensor * src2 = dst->src[2];
|
|
||||||
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
const float * src0_d = (const float *)src0->data;
|
||||||
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
|
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
|
||||||
|
@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
|
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows_x = ggml_nrows(src0);
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
|
@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
// positions tensor
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||||
void * src2_d = nullptr;
|
|
||||||
|
|
||||||
const bool use_src2 = src2 != nullptr;
|
|
||||||
|
|
||||||
if (use_src2) {
|
|
||||||
src2_d = (void *)src2->data;
|
|
||||||
}
|
|
||||||
|
|
||||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
|
||||||
|
|
||||||
if (use_f16) {
|
if (use_f16) {
|
||||||
const half * src1_dd = (const half *)src1_d;
|
const half * src1_dd = (const half *)src1_d;
|
||||||
const half * src2_dd = (const half *)src2_d;
|
|
||||||
|
|
||||||
soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
|
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
|
||||||
} else {
|
} else {
|
||||||
const float * src1_dd = (const float *)src1_d;
|
const float * src1_dd = (const float *)src1_d;
|
||||||
const float * src2_dd = (const float *)src2_d;
|
|
||||||
|
|
||||||
soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
|
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1561,10 +1561,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||||
float scale;
|
float scale;
|
||||||
memcpy(&scale, dst->op_params, sizeof(float));
|
memcpy(&scale, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
|
#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
|
||||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
||||||
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
|
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(src2 == nullptr);
|
|
||||||
|
|
||||||
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
|
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
|
||||||
} break;
|
} break;
|
||||||
|
|
|
@ -9416,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
|
||||||
|
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||||
static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
|
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||||
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
||||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
|
@ -9430,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
|
||||||
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||||
|
|
||||||
float slope = 0.0f;
|
float slope = 1.0f;
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
|
@ -9455,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
|
||||||
const int ix = rowx*ncols + col;
|
const int ix = rowx*ncols + col;
|
||||||
const int iy = rowy*ncols + col;
|
const int iy = rowy*ncols + col;
|
||||||
|
|
||||||
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
|
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
|
||||||
|
|
||||||
vals[col] = val;
|
vals[col] = val;
|
||||||
max_val = sycl::max(max_val, val);
|
max_val = sycl::max(max_val, val);
|
||||||
|
@ -13017,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||||
static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
|
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||||
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
||||||
const size_t n_local_scratch, dpct::queue_ptr stream) {
|
const size_t n_local_scratch, dpct::queue_ptr stream) {
|
||||||
|
@ -13027,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
|
||||||
cgh.parallel_for(
|
cgh.parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
|
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||||
nrows_y, scale, max_bias, m0,
|
nrows_y, scale, max_bias, m0,
|
||||||
m1, n_head_log2, item_ct1,
|
m1, n_head_log2, item_ct1,
|
||||||
local_buf_acc.get_pointer());
|
local_buf_acc.get_pointer());
|
||||||
|
@ -13035,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
|
static void soft_max_f32_sycl(const float * x, const float * mask,
|
||||||
float * dst, const int ncols_x, const int nrows_x,
|
float * dst, const int ncols_x, const int nrows_x,
|
||||||
const int nrows_y, const float scale, const float max_bias,
|
const int nrows_y, const float scale, const float max_bias,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
|
@ -13057,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
|
||||||
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
||||||
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
||||||
if (ncols_x > max_block_size) {
|
if (ncols_x > max_block_size) {
|
||||||
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
switch (ncols_x) {
|
switch (ncols_x) {
|
||||||
case 32:
|
case 32:
|
||||||
soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
case 64:
|
case 64:
|
||||||
soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
case 128:
|
case 128:
|
||||||
soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
case 256:
|
case 256:
|
||||||
soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
case 512:
|
case 512:
|
||||||
soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
case 1024:
|
case 1024:
|
||||||
soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
case 2048:
|
case 2048:
|
||||||
soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
case 4096:
|
case 4096:
|
||||||
soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, n_local_scratch, stream);
|
block_dims, n_local_scratch, stream);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||||
max_bias, m0, m1, n_head_log2, block_nums,
|
max_bias, m0, m1, n_head_log2, block_nums,
|
||||||
block_dims, WARP_SIZE, stream);
|
block_dims, WARP_SIZE, stream);
|
||||||
}
|
}
|
||||||
|
@ -14675,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
const ggml_tensor * src2 = dst->src[2];
|
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
|
||||||
|
|
||||||
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
|
|
||||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
||||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
|
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows_x = ggml_nrows(src0);
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
|
@ -14692,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
||||||
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
||||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
// positions tensor
|
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
||||||
float * src2_dd = nullptr;
|
|
||||||
sycl_pool_alloc<float> src2_f;
|
|
||||||
|
|
||||||
const bool use_src2 = src2 != nullptr;
|
|
||||||
|
|
||||||
if (use_src2) {
|
|
||||||
const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
|
|
||||||
|
|
||||||
if (src2_on_device) {
|
|
||||||
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
|
|
||||||
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
|
||||||
} else {
|
|
||||||
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
|
||||||
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
|
|
||||||
nrows_x, nrows_y, scale, max_bias, main_stream);
|
nrows_x, nrows_y, scale, max_bias, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return nullptr;
|
return nullptr;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);
|
|
||||||
|
|
||||||
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_soft_max_f32;
|
return ctx->device->pipeline_soft_max_f32;
|
||||||
}
|
}
|
||||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue