Update clip.cpp
Added a comprehensive validation check to ensure that critical hyperparameters (hidden_size, n_head, n_layer, image_size, patch_size, projection_dim, eps) are not set to invalid or zero values. a throw statement to handle invalid hyperparameter values by raising an invalid_argument exception. Enhanced the error message within the catch block to log the specific exception message, aiding in debugging.
This commit is contained in:
parent
822b6322de
commit
aa9e72158b
1 changed files with 16 additions and 9 deletions
|
@ -1270,6 +1270,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||||
// load vision model
|
// load vision model
|
||||||
auto & vision_model = new_clip->vision_model;
|
auto & vision_model = new_clip->vision_model;
|
||||||
auto & hparams = vision_model.hparams;
|
auto & hparams = vision_model.hparams;
|
||||||
|
try{
|
||||||
hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision"));
|
hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision"));
|
||||||
hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision"));
|
hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision"));
|
||||||
hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
|
hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
|
||||||
|
@ -1278,7 +1279,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||||
hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE);
|
hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE);
|
||||||
hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
|
hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
|
||||||
hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
|
hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
|
||||||
|
if (hparams.hidden_size == 0 || hparams.n_head == 0 || hparams.n_layer == 0 || hparams.n_intermediate || hparams.image_size == 0 || hparams.patch_size || hparams.projection_dim || hparams.eps) {
|
||||||
|
throw std::invalid_argument("Invalid hyperparameter values");
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::cerr << "Error while loading hyperparameters: " << e.what() << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
|
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
|
||||||
int n = gguf_get_arr_n(ctx, idx);
|
int n = gguf_get_arr_n(ctx, idx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue