put the declaration outside the loop

This commit is contained in:
Zhiyuan Li 2024-11-05 02:45:40 +11:00
parent 6a1e977e34
commit a749ba7701

View file

@ -59,14 +59,15 @@ static void rwkv_wkv_f32_kernel(
float y = 0;
// Process in chunks of 4 for better vectorization
sycl::float4 k4, r4, tf4, td4, s4, kv4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
// Load data in vec4 chunks
sycl::float4 k4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
sycl::float4 r4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
sycl::float4 tf4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
sycl::float4 td4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
sycl::float4 s4(state[j], state[j+1], state[j+2], state[j+3]);
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
// Compute key-value product
sycl::float4 kv4 = k4 * _v;