package netlink

import (
	"syscall"

	"github.com/vishvananda/netlink/nl"
)

func selFromPolicy(sel *nl.XfrmSelector, policy *XfrmPolicy) {
	sel.Family = uint16(nl.FAMILY_V4)
	if policy.Dst != nil {
		sel.Family = uint16(nl.GetIPFamily(policy.Dst.IP))
		sel.Daddr.FromIP(policy.Dst.IP)
		prefixlenD, _ := policy.Dst.Mask.Size()
		sel.PrefixlenD = uint8(prefixlenD)
	}
	if policy.Src != nil {
		sel.Saddr.FromIP(policy.Src.IP)
		prefixlenS, _ := policy.Src.Mask.Size()
		sel.PrefixlenS = uint8(prefixlenS)
	}
	sel.Proto = uint8(policy.Proto)
	sel.Dport = nl.Swap16(uint16(policy.DstPort))
	sel.Sport = nl.Swap16(uint16(policy.SrcPort))
	if sel.Dport != 0 {
		sel.DportMask = ^uint16(0)
	}
	if sel.Sport != 0 {
		sel.SportMask = ^uint16(0)
	}
}

// XfrmPolicyAdd will add an xfrm policy to the system.
// Equivalent to: `ip xfrm policy add $policy`
func XfrmPolicyAdd(policy *XfrmPolicy) error {
	return pkgHandle.XfrmPolicyAdd(policy)
}

// XfrmPolicyAdd will add an xfrm policy to the system.
// Equivalent to: `ip xfrm policy add $policy`
func (h *Handle) XfrmPolicyAdd(policy *XfrmPolicy) error {
	return h.xfrmPolicyAddOrUpdate(policy, nl.XFRM_MSG_NEWPOLICY)
}

// XfrmPolicyUpdate will update an xfrm policy to the system.
// Equivalent to: `ip xfrm policy update $policy`
func XfrmPolicyUpdate(policy *XfrmPolicy) error {
	return pkgHandle.XfrmPolicyUpdate(policy)
}

// XfrmPolicyUpdate will update an xfrm policy to the system.
// Equivalent to: `ip xfrm policy update $policy`
func (h *Handle) XfrmPolicyUpdate(policy *XfrmPolicy) error {
	return h.xfrmPolicyAddOrUpdate(policy, nl.XFRM_MSG_UPDPOLICY)
}

func (h *Handle) xfrmPolicyAddOrUpdate(policy *XfrmPolicy, nlProto int) error {
	req := h.newNetlinkRequest(nlProto, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)

	msg := &nl.XfrmUserpolicyInfo{}
	selFromPolicy(&msg.Sel, policy)
	msg.Priority = uint32(policy.Priority)
	msg.Index = uint32(policy.Index)
	msg.Dir = uint8(policy.Dir)
	msg.Lft.SoftByteLimit = nl.XFRM_INF
	msg.Lft.HardByteLimit = nl.XFRM_INF
	msg.Lft.SoftPacketLimit = nl.XFRM_INF
	msg.Lft.HardPacketLimit = nl.XFRM_INF
	req.AddData(msg)

	tmplData := make([]byte, nl.SizeofXfrmUserTmpl*len(policy.Tmpls))
	for i, tmpl := range policy.Tmpls {
		start := i * nl.SizeofXfrmUserTmpl
		userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
		userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
		userTmpl.Saddr.FromIP(tmpl.Src)
		userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
		userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi))
		userTmpl.Mode = uint8(tmpl.Mode)
		userTmpl.Reqid = uint32(tmpl.Reqid)
		userTmpl.Aalgos = ^uint32(0)
		userTmpl.Ealgos = ^uint32(0)
		userTmpl.Calgos = ^uint32(0)
	}
	if len(tmplData) > 0 {
		tmpls := nl.NewRtAttr(nl.XFRMA_TMPL, tmplData)
		req.AddData(tmpls)
	}
	if policy.Mark != nil {
		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(policy.Mark))
		req.AddData(out)
	}

	_, err := req.Execute(syscall.NETLINK_XFRM, 0)
	return err
}

// XfrmPolicyDel will delete an xfrm policy from the system. Note that
// the Tmpls are ignored when matching the policy to delete.
// Equivalent to: `ip xfrm policy del $policy`
func XfrmPolicyDel(policy *XfrmPolicy) error {
	return pkgHandle.XfrmPolicyDel(policy)
}

// XfrmPolicyDel will delete an xfrm policy from the system. Note that
// the Tmpls are ignored when matching the policy to delete.
// Equivalent to: `ip xfrm policy del $policy`
func (h *Handle) XfrmPolicyDel(policy *XfrmPolicy) error {
	_, err := h.xfrmPolicyGetOrDelete(policy, nl.XFRM_MSG_DELPOLICY)
	return err
}

// XfrmPolicyList gets a list of xfrm policies in the system.
// Equivalent to: `ip xfrm policy show`.
// The list can be filtered by ip family.
func XfrmPolicyList(family int) ([]XfrmPolicy, error) {
	return pkgHandle.XfrmPolicyList(family)
}

// XfrmPolicyList gets a list of xfrm policies in the system.
// Equivalent to: `ip xfrm policy show`.
// The list can be filtered by ip family.
func (h *Handle) XfrmPolicyList(family int) ([]XfrmPolicy, error) {
	req := h.newNetlinkRequest(nl.XFRM_MSG_GETPOLICY, syscall.NLM_F_DUMP)

	msg := nl.NewIfInfomsg(family)
	req.AddData(msg)

	msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWPOLICY)
	if err != nil {
		return nil, err
	}

	var res []XfrmPolicy
	for _, m := range msgs {
		if policy, err := parseXfrmPolicy(m, family); err == nil {
			res = append(res, *policy)
		} else if err == familyError {
			continue
		} else {
			return nil, err
		}
	}
	return res, nil
}

// XfrmPolicyGet gets a the policy described by the index or selector, if found.
// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
func XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
	return pkgHandle.XfrmPolicyGet(policy)
}

// XfrmPolicyGet gets a the policy described by the index or selector, if found.
// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
func (h *Handle) XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
	return h.xfrmPolicyGetOrDelete(policy, nl.XFRM_MSG_GETPOLICY)
}

// XfrmPolicyFlush will flush the policies on the system.
// Equivalent to: `ip xfrm policy flush`
func XfrmPolicyFlush() error {
	return pkgHandle.XfrmPolicyFlush()
}

// XfrmPolicyFlush will flush the policies on the system.
// Equivalent to: `ip xfrm policy flush`
func (h *Handle) XfrmPolicyFlush() error {
	req := h.newNetlinkRequest(nl.XFRM_MSG_FLUSHPOLICY, syscall.NLM_F_ACK)
	_, err := req.Execute(syscall.NETLINK_XFRM, 0)
	return err
}

func (h *Handle) xfrmPolicyGetOrDelete(policy *XfrmPolicy, nlProto int) (*XfrmPolicy, error) {
	req := h.newNetlinkRequest(nlProto, syscall.NLM_F_ACK)

	msg := &nl.XfrmUserpolicyId{}
	selFromPolicy(&msg.Sel, policy)
	msg.Index = uint32(policy.Index)
	msg.Dir = uint8(policy.Dir)
	req.AddData(msg)

	if policy.Mark != nil {
		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(policy.Mark))
		req.AddData(out)
	}

	resType := nl.XFRM_MSG_NEWPOLICY
	if nlProto == nl.XFRM_MSG_DELPOLICY {
		resType = 0
	}

	msgs, err := req.Execute(syscall.NETLINK_XFRM, uint16(resType))
	if err != nil {
		return nil, err
	}

	if nlProto == nl.XFRM_MSG_DELPOLICY {
		return nil, err
	}

	p, err := parseXfrmPolicy(msgs[0], FAMILY_ALL)
	if err != nil {
		return nil, err
	}

	return p, nil
}

func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) {
	msg := nl.DeserializeXfrmUserpolicyInfo(m)

	// This is mainly for the policy dump
	if family != FAMILY_ALL && family != int(msg.Sel.Family) {
		return nil, familyError
	}

	var policy XfrmPolicy

	policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
	policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
	policy.Proto = Proto(msg.Sel.Proto)
	policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
	policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
	policy.Priority = int(msg.Priority)
	policy.Index = int(msg.Index)
	policy.Dir = Dir(msg.Dir)

	attrs, err := nl.ParseRouteAttr(m[msg.Len():])
	if err != nil {
		return nil, err
	}

	for _, attr := range attrs {
		switch attr.Attr.Type {
		case nl.XFRMA_TMPL:
			max := len(attr.Value)
			for i := 0; i < max; i += nl.SizeofXfrmUserTmpl {
				var resTmpl XfrmPolicyTmpl
				tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl])
				resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP()
				resTmpl.Src = tmpl.Saddr.ToIP()
				resTmpl.Proto = Proto(tmpl.XfrmId.Proto)
				resTmpl.Mode = Mode(tmpl.Mode)
				resTmpl.Spi = int(nl.Swap32(tmpl.XfrmId.Spi))
				resTmpl.Reqid = int(tmpl.Reqid)
				policy.Tmpls = append(policy.Tmpls, resTmpl)
			}
		case nl.XFRMA_MARK:
			mark := nl.DeserializeXfrmMark(attr.Value[:])
			policy.Mark = new(XfrmMark)
			policy.Mark.Value = mark.Value
			policy.Mark.Mask = mark.Mask
		}
	}

	return &policy, nil
}