diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp index d39c2d183..c33cfa252 100644 --- a/ggml/src/ggml-sycl/wkv6.cpp +++ b/ggml/src/ggml-sycl/wkv6.cpp @@ -59,14 +59,15 @@ static void rwkv_wkv_f32_kernel( float y = 0; // Process in chunks of 4 for better vectorization + sycl::float4 k4, r4, tf4, td4, s4, kv4; #pragma unroll for (int j = 0; j < head_size; j += 4) { // Load data in vec4 chunks - sycl::float4 k4(_k[j], _k[j+1], _k[j+2], _k[j+3]); - sycl::float4 r4(_r[j], _r[j+1], _r[j+2], _r[j+3]); - sycl::float4 tf4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); - sycl::float4 td4(_td[j], _td[j+1], _td[j+2], _td[j+3]); - sycl::float4 s4(state[j], state[j+1], state[j+2], state[j+3]); + k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); // Compute key-value product sycl::float4 kv4 = k4 * _v;