ggml : add TODO's for F16/F32 mask/pos support in other backends

This commit is contained in:
Georgi Gerganov 2024-04-23 10:01:49 +03:00
parent c129369702
commit 3864eea4cb
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 17 additions and 1 deletions

View file

@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
for (int i = node_start; i < node_end; ++i) { for (int i = node_start; i < node_end; ++i) {
struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src0 = gf->nodes[i]->src[0];
struct ggml_tensor * src1 = gf->nodes[i]->src[1]; struct ggml_tensor * src1 = gf->nodes[i]->src[1];
struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
struct ggml_tensor * dst = gf->nodes[i]; struct ggml_tensor * dst = gf->nodes[i];
GGML_ASSERT(dst->data != nullptr); GGML_ASSERT(dst->data != nullptr);
@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
{ {
float scale; float scale;
memcpy(&scale, dst->op_params, sizeof(float)); memcpy(&scale, dst->op_params, sizeof(float));
#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
GGML_ASSERT(src2 == nullptr);
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
} break; } break;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:

View file

@ -14738,7 +14738,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
const ggml_tensor * src2 = dst->src[2];
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_x = ggml_nrows(src0);
@ -14754,7 +14759,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
float * src2_dd = nullptr; float * src2_dd = nullptr;
sycl_pool_alloc<float> src2_f; sycl_pool_alloc<float> src2_f;
ggml_tensor * src2 = dst->src[2];
const bool use_src2 = src2 != nullptr; const bool use_src2 = src2 != nullptr;
if (use_src2) { if (use_src2) {

View file

@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
} }
return nullptr; return nullptr;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32; return ctx->device->pipeline_soft_max_f32;
} }