metal : optimize softmax for C > 32
This commit is contained in:
parent
41d136b602
commit
56e45a239e
2 changed files with 20 additions and 5 deletions
|
@@ -2217,29 +2217,35 @@ kernel void kernel_flash_attn_ext_f16(
|
||||||
for (int64_t p = tiisg; p < C; p += NW) {
|
for (int64_t p = tiisg; p < C; p += NW) {
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
smax = simd_max(max(smax, s));
|
smax = max(smax, s);
|
||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = max(M[j], s);
|
||||||
}
|
}
|
||||||
|
|
||||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
smax = simd_max(smax);
|
||||||
|
M[j] = simd_max(M[j]);
|
||||||
|
|
||||||
S[j] = S[j]*ms;
|
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
if (tiisg == j) {
|
if (tiisg == j) {
|
||||||
ss[j*T + C + j] = ms;
|
ss[j*T + C + j] = ms;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// local sum
|
||||||
|
half ls = 0.0h;
|
||||||
|
|
||||||
for (int64_t p = tiisg; p < C; p += NW) {
|
for (int64_t p = tiisg; p < C; p += NW) {
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||||
|
|
||||||
S[j] = S[j] + simd_sum(vs);
|
ls += vs;
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
ss[j*T + p] = vs;
|
ss[j*T + p] = vs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
S[j] = S[j]*ms + simd_sum(ls);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -572,9 +572,18 @@ struct test_case {
|
||||||
// duplicate the op
|
// duplicate the op
|
||||||
size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
|
size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
|
||||||
int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
|
int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
|
||||||
|
#if 1
|
||||||
for (int i = 1; i < n_runs; i++) {
|
for (int i = 1; i < n_runs; i++) {
|
||||||
gf->nodes[gf->n_nodes++] = out;
|
gf->nodes[gf->n_nodes++] = out;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
int n_nodes = gf->n_nodes;
|
||||||
|
for (int i = 1; i < n_runs; i++) {
|
||||||
|
for (int j = 0; j < n_nodes; j++) {
|
||||||
|
gf->nodes[gf->n_nodes++] = gf->nodes[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// calculate memory
|
// calculate memory
|
||||||
size_t mem = n_runs * op_size(out);
|
size_t mem = n_runs * op_size(out);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue