ggml : sync latest (SAM + SD operators, CUDA alibi) (#2709)

* ggml : sync latest (SAM + SD operators, CUDA alibi)

ggml-ci

* ggml : fix tabs
This commit is contained in:
Georgi Gerganov 2023-08-22 14:22:08 +03:00 committed by GitHub
parent 8e4364f2af
commit ef3f333d37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 1090 additions and 61 deletions

View file

@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
t04->grad = expand(gb, ggml_add_inplace(ctx0,
ggml_add_inplace(ctx0,