rwkv6: support avx2 avx512 armv8 armv9
This commit is contained in:
parent
f66c75a495
commit
b4254c5550
1 changed file with 316 additions and 0 deletions

ggml/src/ggml.c
@@ -16635,6 +16635,320 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
    size_t h_stride = C / H;
    size_t h_stride_2d = head_size * head_size;
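
    // Data layout, as read off the indexing below: r, k, v and time_decay
    // are indexed as [T, C] (C = H * head_size channels per token),
    // time_faaaa as [C], and each sequence's state as H contiguous
    // head_size x head_size matrices. Per head and token this computes
    //     dst   = r . (time_faaaa * (k (x) v) + state)
    //     state = time_decay * state + (k (x) v)
    // with (x) an outer product over the head dimension (see the comment in
    // the scalar #else branch below). The branch taken is fixed at compile
    // time by the target's feature macros (__AVX2__, __AVX512F__,
    // __ARM_FEATURE_SVE, __ARM_NEON).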

#ifdef __AVX2__
    // AVX2 uses 256-bit vectors = 8 float32 lanes
    const int vec_size = 8;
    const size_t vec_count = head_size / vec_size;

    for (size_t t = 0; t < T; t++) {
        size_t t_offset = t * t_stride;
        size_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
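
        // T / n_seqs tokens belong to each sequence: at the first token of
        // a sequence the previous state comes from the initial state tensor
        // in dst->src[5]; for every later token it is the running state
        // written by the previous step. The same pattern is repeated in the
        // AVX-512, SVE and NEON branches below.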
        for (size_t h = 0; h < H; h++) {
            size_t h_offset = h * h_stride;
            size_t t_h_offset = t_offset + h_offset;
            size_t h_2d_offset = h * h_stride_2d;

            for (size_t i = 0; i < head_size; i++) {
                size_t t_h_i_offset = t_h_offset + i;
                size_t h_i_offset = h_offset + i;
                size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                float k_val = k[t_h_i_offset];
                float r_val = r[t_h_i_offset];
                float time_faaaa_val = time_faaaa[h_i_offset];
                float time_decay_val = time_decay[t_h_i_offset];

                // Broadcast scalar values to vectors
                __m256 k_vec = _mm256_set1_ps(k_val);
                __m256 r_vec = _mm256_set1_ps(r_val);
                __m256 time_faaaa_vec = _mm256_set1_ps(time_faaaa_val);
                __m256 time_decay_vec = _mm256_set1_ps(time_decay_val);

                // Vector processing for chunks of 8 floats
                for (size_t j = 0; j < vec_count; j++) {
                    size_t base_j = j * vec_size;
                    size_t t_h_j_offset = t_h_offset + base_j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

                    // Load 8 elements at once
                    __m256 v_vec = _mm256_loadu_ps(&v[t_h_j_offset]);
                    __m256 prev_state_vec = _mm256_loadu_ps(&state_prev[h_2d_i_j_offset]);
                    __m256 dst_vec = _mm256_loadu_ps(&dst_data[t_h_j_offset]);

                    // Compute kv = v * k
                    __m256 kv_vec = _mm256_mul_ps(v_vec, k_vec);

                    // Compute temp = kv * time_faaaa + prev_state
                    __m256 kv_time_vec = _mm256_mul_ps(kv_vec, time_faaaa_vec);
                    __m256 temp_vec = _mm256_add_ps(kv_time_vec, prev_state_vec);

                    // Update dst: dst += temp * r
                    __m256 result_vec = _mm256_mul_ps(temp_vec, r_vec);
                    dst_vec = _mm256_add_ps(dst_vec, result_vec);
                    _mm256_storeu_ps(&dst_data[t_h_j_offset], dst_vec);

                    // Update state: state = prev_state * time_decay + kv
                    __m256 decay_state_vec = _mm256_mul_ps(prev_state_vec, time_decay_vec);
                    __m256 new_state_vec = _mm256_add_ps(decay_state_vec, kv_vec);
                    _mm256_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
                }

                // Scalar tail for any remainder; in practice head_size is a
                // multiple of vec_size, so this loop does not run.
                for (size_t j = vec_count * vec_size; j < head_size; j++) {
                    size_t t_h_j_offset = t_h_offset + j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + j;

                    float v_val = v[t_h_j_offset];
                    float kv_val = v_val * k_val;
                    float prev_state_val = state_prev[h_2d_i_j_offset];
                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
                    dst_data[t_h_j_offset] += temp_val * r_val;
                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
                }
            }
        }
    }

#elif __AVX512F__
    // AVX-512 uses 512-bit vectors = 16 float32 lanes
    const int vec_size = 16;
    const size_t vec_count = head_size / vec_size;

    for (size_t t = 0; t < T; t++) {
        size_t t_offset = t * t_stride;
        size_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

        for (size_t h = 0; h < H; h++) {
            size_t h_offset = h * h_stride;
            size_t t_h_offset = t_offset + h_offset;
            size_t h_2d_offset = h * h_stride_2d;

            for (size_t i = 0; i < head_size; i++) {
                size_t t_h_i_offset = t_h_offset + i;
                size_t h_i_offset = h_offset + i;
                size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                // Load scalar values
                float k_val = k[t_h_i_offset];
                float r_val = r[t_h_i_offset];
                float time_faaaa_val = time_faaaa[h_i_offset];
                float time_decay_val = time_decay[t_h_i_offset];

                // Broadcast scalar values to ZMM registers (512-bit)
                __m512 k_vec = _mm512_set1_ps(k_val);
                __m512 r_vec = _mm512_set1_ps(r_val);
                __m512 time_faaaa_vec = _mm512_set1_ps(time_faaaa_val);
                __m512 time_decay_vec = _mm512_set1_ps(time_decay_val);

                // Use prefetch to reduce cache misses
#define PREFETCH_OFFSET 2
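                // Fetch the v element and the start of the state row
                // PREFETCH_OFFSET i-iterations ahead; each i consumes a
                // full head_size-long row of state_prev.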
                if (i + PREFETCH_OFFSET < head_size) {
                    _mm_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], _MM_HINT_T0);
                    _mm_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], _MM_HINT_T0);
                }

                // Vector processing for chunks of 16 floats
                for (size_t j = 0; j < vec_count; j++) {
                    size_t base_j = j * vec_size;
                    size_t t_h_j_offset = t_h_offset + base_j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

                    // Load 16 elements at once
                    __m512 v_vec = _mm512_loadu_ps(&v[t_h_j_offset]);
                    __m512 prev_state_vec = _mm512_loadu_ps(&state_prev[h_2d_i_j_offset]);
                    __m512 dst_vec = _mm512_loadu_ps(&dst_data[t_h_j_offset]);

                    // Compute kv = v * k
                    __m512 kv_vec = _mm512_mul_ps(v_vec, k_vec);

                    // Compute temp = kv * time_faaaa + prev_state using FMA
                    __m512 temp_vec = _mm512_fmadd_ps(kv_vec, time_faaaa_vec, prev_state_vec);

                    // Update dst: dst += temp * r using FMA
                    dst_vec = _mm512_fmadd_ps(temp_vec, r_vec, dst_vec);
                    _mm512_storeu_ps(&dst_data[t_h_j_offset], dst_vec);

                    // Update state: state = prev_state * time_decay + kv using FMA
                    __m512 new_state_vec = _mm512_fmadd_ps(prev_state_vec, time_decay_vec, kv_vec);
                    _mm512_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
                }

                // Scalar tail for any remainder; in practice head_size is a
                // multiple of vec_size, so this loop does not run.
                for (size_t j = vec_count * vec_size; j < head_size; j++) {
                    size_t t_h_j_offset = t_h_offset + j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + j;

                    float v_val = v[t_h_j_offset];
                    float kv_val = v_val * k_val;
                    float prev_state_val = state_prev[h_2d_i_j_offset];
                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
                    dst_data[t_h_j_offset] += temp_val * r_val;
                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
                }
            }
        }
    }

#elif __ARM_FEATURE_SVE
    // Get vector length for this CPU
    const size_t vec_size = svcntw();
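    // svcntw() returns the number of 32-bit lanes per SVE vector at run
    // time, so this branch is vector-length agnostic: the same binary
    // covers 128- through 2048-bit implementations.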

    for (size_t t = 0; t < T; t++) {
        size_t t_offset = t * t_stride;
        size_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

        for (size_t h = 0; h < H; h++) {
            size_t h_offset = h * h_stride;
            size_t t_h_offset = t_offset + h_offset;
            size_t h_2d_offset = h * h_stride_2d;

            for (size_t i = 0; i < head_size; i++) {
                size_t t_h_i_offset = t_h_offset + i;
                size_t h_i_offset = h_offset + i;
                size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                float k_val = k[t_h_i_offset];
                float r_val = r[t_h_i_offset];
                float time_faaaa_val = time_faaaa[h_i_offset];
                float time_decay_val = time_decay[t_h_i_offset];

                // Create predicate for active lanes
                svbool_t pg = svwhilelt_b32(0, head_size);
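
                // The predicate doubles as the loop control: svwhilelt_b32
                // keeps lanes [j, head_size) active, masks any partial
                // final vector, and goes all-false once j >= head_size, so
                // this branch needs no scalar tail loop.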

                // Process vectors until done
                size_t j = 0;
                while (svptest_first(svptrue_b32(), pg)) {
                    size_t t_h_j_offset = t_h_offset + j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + j;

                    // Load vectors
                    svfloat32_t v_vec = svld1_f32(pg, &v[t_h_j_offset]);
                    svfloat32_t prev_state_vec = svld1_f32(pg, &state_prev[h_2d_i_j_offset]);
                    svfloat32_t dst_vec = svld1_f32(pg, &dst_data[t_h_j_offset]);

                    // Compute kv = v * k
                    svfloat32_t kv_vec = svmul_n_f32_x(pg, v_vec, k_val);

                    // Compute temp = kv * time_faaaa + prev_state
                    svfloat32_t temp_vec = svmad_n_f32_x(pg, kv_vec, time_faaaa_val, prev_state_vec);

                    // Update dst: dst += temp * r
                    svfloat32_t result_vec = svmad_n_f32_x(pg, temp_vec, r_val, dst_vec);
                    svst1_f32(pg, &dst_data[t_h_j_offset], result_vec);

                    // Update state: state = prev_state * time_decay + kv
                    svfloat32_t new_state_vec = svmad_n_f32_x(pg, prev_state_vec, time_decay_val, kv_vec);
                    svst1_f32(pg, &state_cur[h_2d_i_j_offset], new_state_vec);

                    j += vec_size;
                    pg = svwhilelt_b32(j, head_size);
                }
            }
        }
    }

#elif __ARM_NEON
    // NEON uses 128-bit vectors = 4 float32 lanes
    const int vec_size = 4;
    const size_t vec_count = head_size / vec_size;

    for (size_t t = 0; t < T; t++) {
        size_t t_offset = t * t_stride;
        size_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

        for (size_t h = 0; h < H; h++) {
            size_t h_offset = h * h_stride;
            size_t t_h_offset = t_offset + h_offset;
            size_t h_2d_offset = h * h_stride_2d;

            for (size_t i = 0; i < head_size; i++) {
                size_t t_h_i_offset = t_h_offset + i;
                size_t h_i_offset = h_offset + i;
                size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                float k_val = k[t_h_i_offset];
                float r_val = r[t_h_i_offset];
                float time_faaaa_val = time_faaaa[h_i_offset];
                float time_decay_val = time_decay[t_h_i_offset];

                // Broadcast scalar values to vectors
                float32x4_t k_vec = vdupq_n_f32(k_val);
                float32x4_t r_vec = vdupq_n_f32(r_val);
                float32x4_t time_faaaa_vec = vdupq_n_f32(time_faaaa_val);
                float32x4_t time_decay_vec = vdupq_n_f32(time_decay_val);

                // Use prefetch to reduce cache misses
#ifdef __ARM_FEATURE_PREFETCH
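                // (Note: __ARM_FEATURE_PREFETCH does not appear to be a
                // standard ACLE macro, so this block may never be enabled;
                // __builtin_prefetch itself is always available on
                // GCC/Clang targets.)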
#define PREFETCH_OFFSET 2
                if (i + PREFETCH_OFFSET < head_size) {
                    __builtin_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], 0, 3);
                    __builtin_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], 0, 3);
                }
#endif

                // Vector processing for chunks of 4 floats
                for (size_t j = 0; j < vec_count; j++) {
                    size_t base_j = j * vec_size;
                    size_t t_h_j_offset = t_h_offset + base_j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

                    // Load 4 elements at once
                    float32x4_t v_vec = vld1q_f32(&v[t_h_j_offset]);
                    float32x4_t prev_state_vec = vld1q_f32(&state_prev[h_2d_i_j_offset]);
                    float32x4_t dst_vec = vld1q_f32(&dst_data[t_h_j_offset]);

                    // Compute kv = v * k
                    float32x4_t kv_vec = vmulq_f32(v_vec, k_vec);

                    // Compute temp = kv * time_faaaa + prev_state,
                    // with FMA where available, separate mul+add otherwise
#ifdef __ARM_FEATURE_FMA
                    float32x4_t temp_vec = vfmaq_f32(prev_state_vec, kv_vec, time_faaaa_vec);
                    // Update dst: dst += temp * r
                    dst_vec = vfmaq_f32(dst_vec, temp_vec, r_vec);
                    // Update state: state = prev_state * time_decay + kv
                    float32x4_t new_state_vec = vfmaq_f32(kv_vec, prev_state_vec, time_decay_vec);
#else
                    float32x4_t kv_time = vmulq_f32(kv_vec, time_faaaa_vec);
                    float32x4_t temp_vec = vaddq_f32(kv_time, prev_state_vec);
                    float32x4_t result_vec = vmulq_f32(temp_vec, r_vec);
                    dst_vec = vaddq_f32(dst_vec, result_vec);
                    float32x4_t decay_state_vec = vmulq_f32(prev_state_vec, time_decay_vec);
                    float32x4_t new_state_vec = vaddq_f32(decay_state_vec, kv_vec);
#endif

                    vst1q_f32(&dst_data[t_h_j_offset], dst_vec);
                    vst1q_f32(&state_cur[h_2d_i_j_offset], new_state_vec);
                }

                // Scalar tail for any remainder; in practice head_size is a
                // multiple of vec_size, so this loop does not run.
                for (size_t j = vec_count * vec_size; j < head_size; j++) {
                    size_t t_h_j_offset = t_h_offset + j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + j;

                    float v_val = v[t_h_j_offset];
                    float kv_val = v_val * k_val;
                    float prev_state_val = state_prev[h_2d_i_j_offset];
                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
                    dst_data[t_h_j_offset] += temp_val * r_val;
                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
                }
            }
        }
    }

#else
    // basically fused operations:
    // dst = r @ (time_faaaa * (k @ v) + state),
    // state = time_decay * state + (k @ v),

@@ -16674,7 +16988,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                }
            }
        }

    }
#endif
}
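
For reference, here is a minimal scalar sketch of the per-token, per-head step that all four SIMD branches above implement, matching their scalar tail loops. It is illustrative only: the function and parameter names are hypothetical and not part of ggml or of this commit.

#include <stddef.h>

// Hypothetical reference helper: one (token, head) step of wkv6.
// state is a head_size x head_size row-major matrix updated in place;
// dst accumulates this head's output channels.
static void rwkv6_head_step(size_t head_size,
                            const float * r, const float * k, const float * v,
                            const float * time_faaaa, const float * time_decay,
                            float * state, float * dst) {
    for (size_t i = 0; i < head_size; i++) {
        for (size_t j = 0; j < head_size; j++) {
            float kv   = v[j] * k[i];                          // outer product k (x) v
            float temp = kv * time_faaaa[i] + state[i * head_size + j];
            dst[j]    += temp * r[i];                          // dst = r . (tf*kv + state)
            state[i * head_size + j] =
                state[i * head_size + j] * time_decay[i] + kv; // state = td*state + kv
        }
    }
}

The SIMD branches vectorize the inner j loop of this recurrence; the outer i loop stays scalar because k, r, time_faaaa and time_decay are per-i scalars that get broadcast.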