diff --git a/arch/arm64/kernel/signal.c b/arch/arm64/kernel/signal.c index 882f6d913508..3228b5a1dfe3 100644 --- a/arch/arm64/kernel/signal.c +++ b/arch/arm64/kernel/signal.c @@ -278,6 +278,9 @@ static int restore_sve_fpsimd_context(struct user_ctxs *user) if (__copy_from_user(&sve, user->sve, sizeof(sve))) return -EFAULT; + if (sve.head.size < sizeof(*user->sve)) + return -EINVAL; + if (sve.flags & SVE_SIG_FLAG_SM) { if (!system_supports_sme()) return -EINVAL; @@ -293,7 +296,7 @@ static int restore_sve_fpsimd_context(struct user_ctxs *user) if (sve.vl != vl) return -EINVAL; - if (sve.head.size <= sizeof(*user->sve)) { + if (sve.head.size == sizeof(*user->sve)) { clear_thread_flag(TIF_SVE); current->thread.svcr &= ~SVCR_SM_MASK; current->thread.fp_type = FP_STATE_FPSIMD; @@ -434,10 +437,13 @@ static int restore_za_context(struct user_ctxs *user) if (__copy_from_user(&za, user->za, sizeof(za))) return -EFAULT; + if (za.head.size < sizeof(*user->za)) + return -EINVAL; + if (za.vl != task_get_sme_vl(current)) return -EINVAL; - if (za.head.size <= sizeof(*user->za)) { + if (za.head.size == sizeof(*user->za)) { current->thread.svcr &= ~SVCR_ZA_MASK; return 0; } @@ -614,9 +620,6 @@ static int parse_user_sigframe(struct user_ctxs *user, if (user->fpsimd) goto invalid; - if (size < sizeof(*user->fpsimd)) - goto invalid; - user->fpsimd = (struct fpsimd_context __user *)head; break; @@ -631,9 +634,6 @@ static int parse_user_sigframe(struct user_ctxs *user, if (user->sve) goto invalid; - if (size < sizeof(*user->sve)) - goto invalid; - user->sve = (struct sve_context __user *)head; break; @@ -657,9 +657,6 @@ static int parse_user_sigframe(struct user_ctxs *user, if (user->za) goto invalid; - if (size < sizeof(*user->za)) - goto invalid; - user->za = (struct za_context __user *)head; break;