Fix num_positions and embeddings initialization
This commit is contained in:
parent
77740fb3ad
commit
12536fda75
1 changed files with 10 additions and 8 deletions
|
@ -573,13 +573,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
struct ggml_tensor * embeddings = inp;
|
||||
if (ctx->has_class_embedding) {
|
||||
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
||||
ggml_set_name(embeddings, "embeddings");
|
||||
ggml_set_input(embeddings);
|
||||
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
|
||||
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
|
||||
embeddings = ggml_acc(ctx0, embeddings, inp,
|
||||
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
|
||||
}
|
||||
ggml_set_name(embeddings, "embeddings");
|
||||
ggml_set_input(embeddings);
|
||||
|
||||
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||
|
@ -1846,7 +1846,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
const int image_size = hparams.image_size;
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
|
||||
const int num_positions = num_patches + 1;
|
||||
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
|
||||
|
||||
{
|
||||
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
||||
|
@ -1874,6 +1874,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
|
||||
{
|
||||
if (ctx->has_class_embedding) {
|
||||
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
|
||||
|
||||
void* zero_mem = malloc(ggml_nbytes(embeddings));
|
||||
|
@ -1881,6 +1882,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
|
||||
free(zero_mem);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue