metal : use F32 attention accumulators
This commit is contained in:
parent
fa9e8c6689
commit
c16a7c2688
3 changed files with 81 additions and 93 deletions
15
ggml-metal.m
15
ggml-metal.m
|
@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
||||
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
|
@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||
|
||||
int64_t nsg = 1;
|
||||
|
@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
}
|
||||
nsg /= 2;
|
||||
|
||||
// require power of 2
|
||||
//{
|
||||
// int64_t nsgm = 1;
|
||||
// while (nsgm < nsg) {
|
||||
// nsgm *= 2;
|
||||
// }
|
||||
// GGML_ASSERT(nsg == nsgm);
|
||||
//}
|
||||
|
||||
const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
|
|
156
ggml-metal.metal
156
ggml-metal.metal
|
@ -2169,12 +2169,13 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
const short TF = T/2; // shared memory size per query in (float)
|
||||
const short T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
simdgroup_half8x8 lo[Q8][D8];
|
||||
|
@ -2202,15 +2203,15 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// zero out shared memory SH
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (short i = tiisg; i < SH; i += NW) {
|
||||
ss[j*T + i] = 0.0h;
|
||||
ss[j*TF + i] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
half S[Q] = { [0 ... Q-1] = 0.0h };
|
||||
half M[Q] = { [0 ... Q-1] = -INFINITY };
|
||||
float S[Q] = { [0 ... Q-1] = 0.0h };
|
||||
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
|
@ -2248,7 +2249,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||
|
||||
// prepare diagonal scale matrix
|
||||
simdgroup_half8x8 mscale(scale);
|
||||
simdgroup_float8x8 mscale(scale);
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
|
@ -2261,9 +2262,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// Q*K^T
|
||||
{
|
||||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mqk[Q8];
|
||||
simdgroup_float8x8 mqk[Q8];
|
||||
for (short j = 0; j < Q8; ++j) {
|
||||
mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
|
||||
mqk[j] = make_filled_simdgroup_matrix<float, 8>(0.h);
|
||||
}
|
||||
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
@ -2283,48 +2284,48 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
|
||||
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
|
||||
|
||||
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
|
||||
simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// used to detect blocks full of -INF
|
||||
half smax = -INFINITY;
|
||||
float smax = -INFINITY;
|
||||
|
||||
// online softmax
|
||||
if (C == 32) {
|
||||
half ms[Q];
|
||||
float ms[Q];
|
||||
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const short p = tiisg;
|
||||
|
||||
const half m = M[j];
|
||||
const half s = ss[j*T + p];
|
||||
const float m = M[j];
|
||||
const float s = ss[j*TF + p];
|
||||
|
||||
smax = simd_max(max(smax, s));
|
||||
M[j] = simd_max(max(M[j], s));
|
||||
|
||||
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||
ms[j] = exp(m - M[j]);
|
||||
const float vs = exp(s - M[j]);
|
||||
|
||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*T + p] = vs;
|
||||
ss[j*TF + p] = vs;
|
||||
}
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg < Q) {
|
||||
ss[tiisg*T + C + tiisg] = ms[tiisg];
|
||||
ss[tiisg*TF + C + tiisg] = ms[tiisg];
|
||||
}
|
||||
} else {
|
||||
half ms[Q];
|
||||
float ms[Q];
|
||||
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const half m = M[j];
|
||||
const float m = M[j];
|
||||
|
||||
for (short p = tiisg; p < C; p += NW) {
|
||||
const half s = ss[j*T + p];
|
||||
const float s = ss[j*TF + p];
|
||||
|
||||
smax = max(smax, s);
|
||||
M[j] = max(M[j], s);
|
||||
|
@ -2333,20 +2334,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
smax = simd_max(smax);
|
||||
M[j] = simd_max(M[j]);
|
||||
|
||||
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||
ms[j] = exp(m - M[j]);
|
||||
|
||||
// local sum
|
||||
half ls = 0.0h;
|
||||
float ls = 0.0h;
|
||||
|
||||
for (short p = tiisg; p < C; p += NW) {
|
||||
const half s = ss[j*T + p];
|
||||
const float s = ss[j*TF + p];
|
||||
|
||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||
const float vs = exp(s - M[j]);
|
||||
|
||||
ls += vs;
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*T + p] = vs;
|
||||
ss[j*TF + p] = vs;
|
||||
}
|
||||
|
||||
S[j] = S[j]*ms[j] + simd_sum(ls);
|
||||
|
@ -2354,7 +2355,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg < Q) {
|
||||
ss[tiisg*T + C + tiisg] = ms[tiisg];
|
||||
ss[tiisg*TF + C + tiisg] = ms[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2365,8 +2366,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
|
||||
// O = diag(ms)*O
|
||||
for (short j = 0; j < Q8; ++j) {
|
||||
simdgroup_half8x8 mm;
|
||||
simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
|
||||
simdgroup_float8x8 mm;
|
||||
simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_multiply(lo[j][i], mm, lo[j][i]);
|
||||
|
@ -2383,8 +2384,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
|
||||
|
||||
for (short j = 0; j < Q8; ++j) {
|
||||
simdgroup_half8x8 mv;
|
||||
simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
|
||||
simdgroup_float8x8 mv;
|
||||
simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
|
||||
}
|
||||
|
@ -2396,16 +2397,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S[j];
|
||||
ss[j*T + 1] = M[j];
|
||||
ss[j*TF + 0] = S[j];
|
||||
ss[j*TF + 1] = M[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reduce the warps sequentially
|
||||
for (short sg = 1; sg < nsg; ++sg) {
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
float S = { 0.0h };
|
||||
float M = { -FLT_MAX/2 };
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
|
@ -2423,36 +2424,36 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*SH + 0];
|
||||
const float S0 = ss[j*TF + 0];
|
||||
const float S1 = ss[j*TF + sg*SH + 0];
|
||||
|
||||
const half M0 = ss[j*T + 1];
|
||||
const half M1 = ss[j*T + sg*SH + 1];
|
||||
const float M0 = ss[j*TF + 1];
|
||||
const float M1 = ss[j*TF + sg*SH + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
|
||||
const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
|
||||
const float ms0 = exp(M0 - M);
|
||||
const float ms1 = exp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M;
|
||||
ss[j*TF + 0] = S;
|
||||
ss[j*TF + 1] = M;
|
||||
|
||||
ss[j*T + C + j ] = ms0;
|
||||
ss[j*T + C + j + sg*SH] = ms1;
|
||||
ss[j*TF + C + j ] = ms0;
|
||||
ss[j*TF + C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (short j = 0; j < Q8; ++j) {
|
||||
simdgroup_half8x8 t;
|
||||
simdgroup_half8x8 ms0;
|
||||
simdgroup_half8x8 ms1;
|
||||
simdgroup_float8x8 ms0;
|
||||
simdgroup_float8x8 ms1;
|
||||
|
||||
simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false);
|
||||
simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
|
||||
simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false);
|
||||
simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
|
||||
|
@ -2478,7 +2479,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
const float S = ss[j*TF + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
|
||||
|
@ -2494,8 +2495,6 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f
|
|||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
|
||||
|
||||
#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
||||
|
||||
template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_vec_f16(
|
||||
device const char * q,
|
||||
|
@ -2539,18 +2538,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
const short iq1 = tgpig[0];
|
||||
|
||||
const short D4 = D/4;
|
||||
const short D8 = D/8;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + 1); // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
const short T4 = T/4; // shared memory size per query in (half4)
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
|
||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
half4 lo[D4/NW];
|
||||
|
@ -2579,8 +2576,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
half S = { 0.0h };
|
||||
half M = { -HALF_MAX_HALF };
|
||||
float S = { 0.0h };
|
||||
float M = { -FLT_MAX/2 };
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
|
@ -2628,7 +2625,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
half4 mqk = { 0.0h };
|
||||
float4 mqk = { 0.0h };
|
||||
|
||||
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
|
@ -2642,7 +2639,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
mk[2] = pk4[i + 2*(nb11/8)];
|
||||
mk[3] = pk4[i + 3*(nb11/8)];
|
||||
|
||||
mqk += mq[i] * mk;
|
||||
mqk += (float4) (mq[i] * mk);
|
||||
}
|
||||
|
||||
// reduce the results from the threads in the simdgroup
|
||||
|
@ -2654,7 +2651,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
|
||||
// mqk = mqk*scale + mask
|
||||
if (tiisg == 0) {
|
||||
half4 mm = mp4[ic/4 + cc];
|
||||
float4 mm = (float4) mp4[ic/4 + cc];
|
||||
mqk = mqk*scale + mm;
|
||||
|
||||
ss4[cc] = mqk;
|
||||
|
@ -2666,13 +2663,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
{
|
||||
const short p = tiisg;
|
||||
|
||||
const half m = M;
|
||||
const half s = ss[p];
|
||||
const float m = M;
|
||||
const float s = ss[p];
|
||||
|
||||
M = simd_max(max(M, s));
|
||||
|
||||
const half ms = exp(m - M);
|
||||
const half vs = exp(s - M);
|
||||
const float ms = exp(m - M);
|
||||
const float vs = exp(s - M);
|
||||
|
||||
S = S*ms + simd_sum(vs);
|
||||
|
||||
|
@ -2696,6 +2693,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
|
||||
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
|
||||
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
||||
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
||||
|
@ -2724,18 +2722,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
// parallel reduce
|
||||
for (short r = nsg/2; r > 0; r >>= 1) {
|
||||
if (sgitg < r) {
|
||||
const half S0 = ss[ 0];
|
||||
const half S1 = ss[r*SH + 0];
|
||||
const float S0 = ss[ 0];
|
||||
const float S1 = ss[r*SH + 0];
|
||||
|
||||
const half M0 = ss[ 1];
|
||||
const half M1 = ss[r*SH + 1];
|
||||
const float M0 = ss[ 1];
|
||||
const float M1 = ss[r*SH + 1];
|
||||
|
||||
const half M = max(M0, M1);
|
||||
const float M = max(M0, M1);
|
||||
|
||||
const half ms0 = exp(M0 - M);
|
||||
const half ms1 = exp(M1 - M);
|
||||
const float ms0 = exp(M0 - M);
|
||||
const float ms1 = exp(M1 - M);
|
||||
|
||||
const half S = S0*ms0 + S1*ms1;
|
||||
const float S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[0] = S;
|
||||
|
@ -2756,7 +2754,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
const half S = ss[0];
|
||||
const float S = ss[0];
|
||||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
|
|
3
ggml.c
3
ggml.c
|
@ -14882,12 +14882,13 @@ static void ggml_compute_forward_flash_attn_ext(
|
|||
struct ggml_tensor * dst) {
|
||||
switch (dst->op_params[1]) {
|
||||
case GGML_PREC_DEFAULT:
|
||||
case GGML_PREC_F32:
|
||||
{
|
||||
// uses F32 accumulators
|
||||
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
// TODO: implement F32 precision
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue