diff --git a/arch/riscv/net/bpf_jit.h b/arch/riscv/net/bpf_jit.h index a52a3dda18c3..d21c6c92a683 100644 --- a/arch/riscv/net/bpf_jit.h +++ b/arch/riscv/net/bpf_jit.h @@ -431,11 +431,21 @@ static inline u32 rv_mulhu(u8 rd, u8 rs1, u8 rs2) return rv_r_insn(1, rs2, rs1, 3, rd, 0x33); } +static inline u32 rv_div(u8 rd, u8 rs1, u8 rs2) +{ + return rv_r_insn(1, rs2, rs1, 4, rd, 0x33); +} + static inline u32 rv_divu(u8 rd, u8 rs1, u8 rs2) { return rv_r_insn(1, rs2, rs1, 5, rd, 0x33); } +static inline u32 rv_rem(u8 rd, u8 rs1, u8 rs2) +{ + return rv_r_insn(1, rs2, rs1, 6, rd, 0x33); +} + static inline u32 rv_remu(u8 rd, u8 rs1, u8 rs2) { return rv_r_insn(1, rs2, rs1, 7, rd, 0x33); @@ -776,11 +786,21 @@ static inline u32 rv_mulw(u8 rd, u8 rs1, u8 rs2) return rv_r_insn(1, rs2, rs1, 0, rd, 0x3b); } +static inline u32 rv_divw(u8 rd, u8 rs1, u8 rs2) +{ + return rv_r_insn(1, rs2, rs1, 4, rd, 0x3b); +} + static inline u32 rv_divuw(u8 rd, u8 rs1, u8 rs2) { return rv_r_insn(1, rs2, rs1, 5, rd, 0x3b); } +static inline u32 rv_remw(u8 rd, u8 rs1, u8 rs2) +{ + return rv_r_insn(1, rs2, rs1, 6, rd, 0x3b); +} + static inline u32 rv_remuw(u8 rd, u8 rs1, u8 rs2) { return rv_r_insn(1, rs2, rs1, 7, rd, 0x3b); diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c index 076619f4aa82..ba4535ac194f 100644 --- a/arch/riscv/net/bpf_jit_comp64.c +++ b/arch/riscv/net/bpf_jit_comp64.c @@ -1107,13 +1107,19 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, break; case BPF_ALU | BPF_DIV | BPF_X: case BPF_ALU64 | BPF_DIV | BPF_X: - emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx); + if (off) + emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx); + else + emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) emit_zext_32(rd, ctx); break; case BPF_ALU | BPF_MOD | BPF_X: case BPF_ALU64 | BPF_MOD | BPF_X: - emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx); + if (off) + emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx); + else + emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx); if (!is64 && !aux->verifier_zext) emit_zext_32(rd, ctx); break; @@ -1284,16 +1290,24 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, case BPF_ALU | BPF_DIV | BPF_K: case BPF_ALU64 | BPF_DIV | BPF_K: emit_imm(RV_REG_T1, imm, ctx); - emit(is64 ? rv_divu(rd, rd, RV_REG_T1) : - rv_divuw(rd, rd, RV_REG_T1), ctx); + if (off) + emit(is64 ? rv_div(rd, rd, RV_REG_T1) : + rv_divw(rd, rd, RV_REG_T1), ctx); + else + emit(is64 ? rv_divu(rd, rd, RV_REG_T1) : + rv_divuw(rd, rd, RV_REG_T1), ctx); if (!is64 && !aux->verifier_zext) emit_zext_32(rd, ctx); break; case BPF_ALU | BPF_MOD | BPF_K: case BPF_ALU64 | BPF_MOD | BPF_K: emit_imm(RV_REG_T1, imm, ctx); - emit(is64 ? rv_remu(rd, rd, RV_REG_T1) : - rv_remuw(rd, rd, RV_REG_T1), ctx); + if (off) + emit(is64 ? rv_rem(rd, rd, RV_REG_T1) : + rv_remw(rd, rd, RV_REG_T1), ctx); + else + emit(is64 ? rv_remu(rd, rd, RV_REG_T1) : + rv_remuw(rd, rd, RV_REG_T1), ctx); if (!is64 && !aux->verifier_zext) emit_zext_32(rd, ctx); break;