Update clip.cpp
Added a comprehensive validation check to ensure that critical hyperparameters (hidden_size, n_head, n_layer, image_size, patch_size, projection_dim, eps) are not set to invalid or zero values. a throw statement to handle invalid hyperparameter values by raising an invalid_argument exception. Enhanced the error message within the catch block to log the specific exception message, aiding in debugging.
This commit is contained in:
parent
822b6322de
commit
aa9e72158b
1 changed files with 16 additions and 9 deletions
|
@ -1270,15 +1270,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
// load vision model
|
||||
auto & vision_model = new_clip->vision_model;
|
||||
auto & hparams = vision_model.hparams;
|
||||
hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision"));
|
||||
hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision"));
|
||||
hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
|
||||
hparams.n_layer = get_u32(ctx, format(KEY_N_BLOCK, "vision"));
|
||||
hparams.image_size = get_u32(ctx, KEY_IMAGE_SIZE);
|
||||
hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE);
|
||||
hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
|
||||
hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
|
||||
|
||||
try{
|
||||
hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "vision"));
|
||||
hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "vision"));
|
||||
hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
|
||||
hparams.n_layer = get_u32(ctx, format(KEY_N_BLOCK, "vision"));
|
||||
hparams.image_size = get_u32(ctx, KEY_IMAGE_SIZE);
|
||||
hparams.patch_size = get_u32(ctx, KEY_PATCH_SIZE);
|
||||
hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
|
||||
hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
|
||||
if (hparams.hidden_size == 0 || hparams.n_head == 0 || hparams.n_layer == 0 || hparams.n_intermediate || hparams.image_size == 0 || hparams.patch_size || hparams.projection_dim || hparams.eps) {
|
||||
throw std::invalid_argument("Invalid hyperparameter values");
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error while loading hyperparameters: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
|
||||
int n = gguf_get_arr_n(ctx, idx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue