From cba03408710588b3f8a7abac8bc72c1088f78c28 Mon Sep 17 00:00:00 2001 From: Tejaakshaykumar <147340353+Tejaakshaykumar@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:46:59 +0530 Subject: [PATCH] Refactored error handling for hyperparameter validation in clip.cpp Removed the try-catch block and throw statements Added validation checks for all hyperparameters, ensuring they are not set to invalid or zero values. Used fprintf to log error messages when invalid hyperparameters are encountered. --- examples/llava/clip.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 7155832b4..3327cb266 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -1270,7 +1270,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { // load vision model auto & vision_model = new_clip->vision_model; auto & hparams = vision_model.hparams; - try{ hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision")); hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision")); hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision")); @@ -1279,12 +1278,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE); hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision")); hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision")); - if (hparams.hidden_size == 0 || hparams.n_head == 0 || hparams.n_layer == 0 || hparams.n_intermediate || hparams.image_size == 0 || hparams.patch_size || hparams.projection_dim || hparams.eps) { - throw std::invalid_argument("Invalid hyperparameter values"); + if (hparams.hidden_size == 0 || hparams.n_head == 0 || hparams.n_layer == 0 || hparams.n_intermediate == 0 || hparams.image_size == 0 || hparams.patch_size == 0 || hparams.projection_dim == 0 || hparams.eps == 0) { + fprintf(stderr, "Error: Invalid hyperparameter values\n"); + return false; } - } catch (const std::exception& e) { - fprintf(stderr, "Error while loading hyperparameters: %s\n", e.what()); - return false; } try { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);