Update clip.cpp

Added a comprehensive validation check to ensure that critical hyperparameters (hidden_size, n_head, n_layer, image_size, patch_size, projection_dim, eps) are not set to invalid or zero values.

 a throw statement to handle invalid hyperparameter values by raising an invalid_argument exception.

Enhanced the error message within the catch block to log the specific exception message, aiding in debugging.
This commit is contained in:
Tejaakshaykumar 2024-09-14 18:54:26 +05:30 committed by GitHub
parent 822b6322de
commit aa9e72158b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1270,15 +1270,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// load vision model
auto & vision_model = new_clip->vision_model;
auto & hparams = vision_model.hparams;
hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision"));
hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision"));
hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
hparams.n_layer = get_u32(ctx, format(KEY_N_BLOCK, "vision"));
hparams.image_size = get_u32(ctx, KEY_IMAGE_SIZE);
hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE);
hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
try{
hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision"));
hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision"));
hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
hparams.n_layer = get_u32(ctx, format(KEY_N_BLOCK, "vision"));
hparams.image_size = get_u32(ctx, KEY_IMAGE_SIZE);
hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE);
hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
if (hparams.hidden_size == 0 || hparams.n_head == 0 || hparams.n_layer == 0 || hparams.n_intermediate || hparams.image_size == 0 || hparams.patch_size || hparams.projection_dim || hparams.eps) {
throw std::invalid_argument("Invalid hyperparameter values");
}
} catch (const std::exception& e) {
std::cerr << "Error while loading hyperparameters: " << e.what() << std::endl;
return false;
}
try {
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
int n = gguf_get_arr_n(ctx, idx);