softmax: handle SYCL exceptions and add debug logs
This commit is contained in:
parent
bba4b66a81
commit
6dbb7ac827
4 changed files with 21 additions and 8 deletions
|
@ -2752,7 +2752,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
|||
ggml_sycl_group_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_sycl_op_concat(ctx, dst);
|
||||
ggml_sycl_concat(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_UPSCALE:
|
||||
ggml_sycl_upscale(ctx, dst);
|
||||
|
@ -2817,7 +2817,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
|||
ggml_sycl_diag_mask_inf(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
ggml_sycl_op_soft_max(ctx, dst);
|
||||
ggml_sycl_softmax(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
ggml_sycl_rope(ctx, dst);
|
||||
|
|
|
@ -224,7 +224,7 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
static void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
@ -249,13 +249,26 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
|
||||
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
|
||||
GGML_SYCL_DEBUG("%s: Mask precision: F16\n", __func__);
|
||||
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
|
||||
main_stream, ctx.device);
|
||||
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
|
||||
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
||||
GGML_SYCL_DEBUG("%s: Mask precision: F32\n", __func__);
|
||||
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
} else {
|
||||
/* mask unavailable */
|
||||
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
GGML_SYCL_DEBUG("%s: No mask supplied\n", __func__);
|
||||
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream,
|
||||
ctx.device);
|
||||
}
|
||||
} catch (const sycl::exception & exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
// Public wrapper for the softmax op: emits a debug log on entry, forwards
// `ctx` and `dst` unchanged to ggml_sycl_op_soft_max, then emits a second
// debug log once that call has returned.
void ggml_sycl_softmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_soft_max(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
|
|
@ -15,6 +15,6 @@
|
|||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
|
||||
void ggml_sycl_softmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_SOFTMAX_HPP
|
||||
|
|
|
@ -27,7 +27,7 @@ static void sum_rows_f32_sycl(const float * x, float * dst, const int ncols, con
|
|||
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||
|
||||
const int64_t ne = ggml_nelements(dst->src[0]);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
@ -43,7 +43,7 @@ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|||
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||
|
||||
const int64_t ncols = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
|
@ -68,5 +68,5 @@ void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_sum_rows(ctx, dst);
|
||||
GML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue