This commit is contained in:
slaren 2023-07-25 15:50:57 +02:00
parent 8a927cf487
commit e25e15c9c5

2
ggml.c
View file

@ -14777,7 +14777,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN:
{ {
const int32_t t = ggml_get_i32_1d(tensor->src[3], 0); const int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1); GGML_ASSERT(t == 0 || t == 1);
const bool masked = t != 0; const bool masked = t != 0;
ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);