package cgroups

import (
	"bufio"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"

	"golang.org/x/sys/unix"

	specs "github.com/opencontainers/runtime-spec/specs-go"
)

// NewMemory returns a memory controller rooted at the Memory subsystem
// directory beneath root.
func NewMemory(root string) *memoryController {
	ctrl := &memoryController{}
	ctrl.root = filepath.Join(root, string(Memory))
	return ctrl
}

// memoryController reads and writes the memory.* control files of cgroups
// located beneath root.
type memoryController struct {
	// root is the subsystem directory; cgroup paths are joined onto it by Path.
	root string
}

// Name reports the subsystem this controller manages (Memory).
func (m *memoryController) Name() Name {
	return Name(Memory)
}

// Path returns the directory of the cgroup at path inside this controller's
// root.
func (m *memoryController) Path(path string) string {
	dir := filepath.Join(m.root, path)
	return dir
}

// Create makes the directory for the cgroup at path. When resources carries
// a Memory section, Create also applies its settings through set.
//
// Only the directory is created when resources or resources.Memory is nil.
// If a kernel memory limit is requested, memory.kmem.limit_in_bytes is first
// written with 1 and then -1 before any other setting. An EBUSY from that
// step is turned into an explanatory error by checkEBUSY.
func (m *memoryController) Create(path string, resources *specs.LinuxResources) error {
	if err := os.MkdirAll(m.Path(path), defaultDirPerm); err != nil {
		return err
	}
	// Guard against a nil resources as well as a nil Memory section; either
	// means there is nothing to configure.
	if resources == nil || resources.Memory == nil {
		return nil
	}
	if resources.Memory.Kernel != nil {
		// Check if kernel memory is enabled
		// We have to limit the kernel memory here as it won't be accounted at all
		// until a limit is set on the cgroup and limit cannot be set once the
		// cgroup has children, or if there are already tasks in the cgroup.
		kmemLimit := filepath.Join(m.Path(path), "memory.kmem.limit_in_bytes")
		for _, i := range []int64{1, -1} {
			if err := ioutil.WriteFile(
				kmemLimit,
				[]byte(strconv.FormatInt(i, 10)),
				defaultFilePerm,
			); err != nil {
				return checkEBUSY(err)
			}
		}
	}
	return m.set(path, getMemorySettings(resources))
}

// Update applies the memory settings in resources to the existing cgroup at
// path. It is a no-op when resources or resources.Memory is nil.
//
// The swap limit (memory.memsw.limit_in_bytes) must never be below the memory
// limit. So when both a positive memory limit and a swap limit are given, the
// two writes are ordered to keep that invariant at every step:
//   - swap == -1 (unlimited): write swap first; unlimited is always valid.
//   - swap > 0 and larger than the current memory limit: write swap first,
//     then lower/raise the memory limit underneath it.
//   - otherwise: write the memory limit first (the default order).
func (m *memoryController) Update(path string, resources *specs.LinuxResources) error {
	if resources == nil || resources.Memory == nil {
		return nil
	}
	mem := resources.Memory
	settings := getMemorySettings(resources)
	if mem.Limit != nil && *mem.Limit > 0 && mem.Swap != nil {
		swapFirst := false
		switch {
		case *mem.Swap == -1:
			// Writing an unlimited swap value first can never violate the
			// invariant, whereas writing a larger memory limit first may be
			// rejected while the old, finite swap limit is still in place.
			swapFirst = true
		case *mem.Swap > 0:
			current, err := readUint(filepath.Join(m.Path(path), "memory.limit_in_bytes"))
			if err != nil {
				return err
			}
			swapFirst = current < uint64(*mem.Swap)
		}
		if swapFirst {
			// getMemorySettings places limit_in_bytes at index 0 and
			// memsw.limit_in_bytes at index 1.
			settings[0], settings[1] = settings[1], settings[0]
		}
	}
	return m.set(path, settings)
}

// Stat fills stats.Memory for the cgroup at path.
//
// It parses memory.stat, then reads usage_in_bytes, max_usage_in_bytes,
// failcnt and limit_in_bytes for the base controller and for the memsw, kmem
// and kmem.tcp modules. Files of those optional modules are skipped when they
// do not exist (their entry fields stay zero). A missing base-controller file
// or any other read error is returned.
func (m *memoryController) Stat(path string, stats *Stats) error {
	f, err := os.Open(filepath.Join(m.Path(path), "memory.stat"))
	if err != nil {
		return err
	}
	defer f.Close()
	stats.Memory = &MemoryStat{}
	if err := m.parseStats(f, stats.Memory); err != nil {
		return err
	}
	for _, t := range []struct {
		module string
		entry  *MemoryEntry
	}{
		{
			module: "",
			entry:  &stats.Memory.Usage,
		},
		{
			module: "memsw",
			entry:  &stats.Memory.Swap,
		},
		{
			module: "kmem",
			entry:  &stats.Memory.Kernel,
		},
		{
			module: "kmem.tcp",
			entry:  &stats.Memory.KernelTCP,
		},
	} {
		for _, tt := range []struct {
			name  string
			value *uint64
		}{
			{
				name:  "usage_in_bytes",
				value: &t.entry.Usage,
			},
			{
				name:  "max_usage_in_bytes",
				value: &t.entry.Max,
			},
			{
				name:  "failcnt",
				value: &t.entry.Failcnt,
			},
			{
				name:  "limit_in_bytes",
				value: &t.entry.Limit,
			},
		} {
			parts := []string{"memory"}
			if t.module != "" {
				parts = append(parts, t.module)
			}
			parts = append(parts, tt.name)
			v, err := readUint(filepath.Join(m.Path(path), strings.Join(parts, ".")))
			if err != nil {
				// The memsw/kmem accounting files may be absent depending on
				// kernel configuration; do not fail the whole Stat for them.
				// NOTE(review): relies on readUint returning the underlying
				// os error (e.g. *os.PathError) so os.IsNotExist can see it.
				if t.module != "" && os.IsNotExist(err) {
					continue
				}
				return err
			}
			*tt.value = v
		}
	}
	return nil
}

// OOMEventFD creates a close-on-exec eventfd and hands it, together with the
// cgroup's memory.oom_control descriptor, to writeEventFD. On success the
// eventfd is returned and is not closed here. If registration fails, the
// eventfd is closed before returning the error.
func (m *memoryController) OOMEventFD(path string) (uintptr, error) {
	dir := m.Path(path)
	oomControl, err := os.Open(filepath.Join(dir, "memory.oom_control"))
	if err != nil {
		return 0, err
	}
	defer oomControl.Close()

	efd, _, errno := unix.RawSyscall(unix.SYS_EVENTFD2, 0, unix.FD_CLOEXEC, 0)
	if errno != 0 {
		return 0, errno
	}
	if err = writeEventFD(dir, oomControl.Fd(), efd); err != nil {
		unix.Close(int(efd))
		return 0, err
	}
	return efd, nil
}

// writeEventFD writes "<efd> <cfd>" to root/cgroup.event_control; OOMEventFD
// uses it to attach its eventfd to the memory.oom_control descriptor.
//
// The file must already exist (it is opened write-only without O_CREATE).
// Errors from opening, writing or closing the file are all reported.
func writeEventFD(root string, cfd, efd uintptr) error {
	f, err := os.OpenFile(filepath.Join(root, "cgroup.event_control"), os.O_WRONLY, 0)
	if err != nil {
		return err
	}
	// fmt.Fprintf formats into a buffer and issues a single Write, matching
	// the original single WriteString call.
	if _, err := fmt.Fprintf(f, "%d %d", efd, cfd); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}

// parseStats reads "key value" lines from r (the contents of memory.stat)
// and copies the recognised counters into stat. Keys absent from the input
// leave the corresponding field zero.
//
// A line that parseKV rejects yields an error prefixed with its zero-based
// line index. A read error from r is returned as-is.
func (m *memoryController) parseStats(r io.Reader, stat *MemoryStat) error {
	var (
		raw  = make(map[string]uint64)
		sc   = bufio.NewScanner(r)
		line int
	)
	for sc.Scan() {
		key, v, err := parseKV(sc.Text())
		if err != nil {
			return fmt.Errorf("%d: %v", line, err)
		}
		raw[key] = v
		line++
	}
	// Scan returns false on both EOF and read errors; only Err, checked after
	// the loop, distinguishes the two.
	if err := sc.Err(); err != nil {
		return err
	}
	stat.Cache = raw["cache"]
	stat.RSS = raw["rss"]
	stat.RSSHuge = raw["rss_huge"]
	stat.MappedFile = raw["mapped_file"]
	stat.Dirty = raw["dirty"]
	stat.Writeback = raw["writeback"]
	stat.PgPgIn = raw["pgpgin"]
	stat.PgPgOut = raw["pgpgout"]
	stat.PgFault = raw["pgfault"]
	stat.PgMajFault = raw["pgmajfault"]
	stat.InactiveAnon = raw["inactive_anon"]
	stat.ActiveAnon = raw["active_anon"]
	stat.InactiveFile = raw["inactive_file"]
	stat.ActiveFile = raw["active_file"]
	stat.Unevictable = raw["unevictable"]
	stat.HierarchicalMemoryLimit = raw["hierarchical_memory_limit"]
	stat.HierarchicalSwapLimit = raw["hierarchical_memsw_limit"]
	stat.TotalCache = raw["total_cache"]
	stat.TotalRSS = raw["total_rss"]
	stat.TotalRSSHuge = raw["total_rss_huge"]
	stat.TotalMappedFile = raw["total_mapped_file"]
	stat.TotalDirty = raw["total_dirty"]
	stat.TotalWriteback = raw["total_writeback"]
	stat.TotalPgPgIn = raw["total_pgpgin"]
	stat.TotalPgPgOut = raw["total_pgpgout"]
	stat.TotalPgFault = raw["total_pgfault"]
	stat.TotalPgMajFault = raw["total_pgmajfault"]
	stat.TotalInactiveAnon = raw["total_inactive_anon"]
	stat.TotalActiveAnon = raw["total_active_anon"]
	stat.TotalInactiveFile = raw["total_inactive_file"]
	stat.TotalActiveFile = raw["total_active_file"]
	stat.TotalUnevictable = raw["total_unevictable"]
	return nil
}

// set writes every setting whose value is non-nil to the file memory.<name>
// inside the cgroup at path. Writes happen in slice order and stop at the
// first error.
func (m *memoryController) set(path string, settings []memorySettings) error {
	dir := m.Path(path)
	for _, s := range settings {
		if s.value == nil {
			continue
		}
		file := filepath.Join(dir, "memory."+s.name)
		data := []byte(strconv.FormatInt(*s.value, 10))
		if err := ioutil.WriteFile(file, data, defaultFilePerm); err != nil {
			return err
		}
	}
	return nil
}

// memorySettings pairs a memory control file with the value to write to it.
type memorySettings struct {
	// name is the file name without the "memory." prefix, which set adds.
	name string
	// value is written in decimal; nil means set leaves the file untouched.
	value *int64
}

// getMemorySettings converts resources.Memory into the ordered list of
// memory.* files to write. Entries with a nil value are skipped by set.
//
// Order matters: Update swaps indexes 0 (limit_in_bytes) and 1
// (memsw.limit_in_bytes), so keep them first.
func getMemorySettings(resources *specs.LinuxResources) []memorySettings {
	mem := resources.Memory

	// Swappiness arrives unsigned; set expects *int64.
	var swappiness *int64
	if s := mem.Swappiness; s != nil {
		converted := int64(*s)
		swappiness = &converted
	}

	settings := []memorySettings{
		{name: "limit_in_bytes", value: mem.Limit},
		{name: "memsw.limit_in_bytes", value: mem.Swap},
		{name: "kmem.limit_in_bytes", value: mem.Kernel},
		{name: "kmem.tcp.limit_in_bytes", value: mem.KernelTCP},
		{name: "oom_control", value: getOomControlValue(mem)},
		{name: "swappiness", value: swappiness},
	}
	return settings
}

// checkEBUSY replaces an *os.PathError wrapping EBUSY with a message that
// explains why memory.kmem.limit_in_bytes could not be set. Any other error,
// including nil, is returned unchanged.
func checkEBUSY(err error) error {
	pathErr, isPathErr := err.(*os.PathError)
	if !isPathErr {
		return err
	}
	if errno, isErrno := pathErr.Err.(syscall.Errno); !isErrno || errno != syscall.EBUSY {
		return err
	}
	return fmt.Errorf(
		"failed to set memory.kmem.limit_in_bytes, because either tasks have already joined this cgroup or it has children")
}

// getOomControlValue returns a pointer to 1 when the spec asks for the OOM
// killer to be disabled. Otherwise it returns nil, so set leaves
// memory.oom_control untouched.
func getOomControlValue(mem *specs.LinuxMemory) *int64 {
	if disable := mem.DisableOOMKiller; disable == nil || !*disable {
		return nil
	}
	one := int64(1)
	return &one
}