refactor: add platform

This commit is contained in:
soulteary 2023-01-09 23:14:52 +08:00
parent 8b887fc35d
commit 2b2997ce49
No known key found for this signature in database
GPG key ID: 8107DBA6BC84D986
5 changed files with 17 additions and 14 deletions

View file

@ -1,13 +1,13 @@
//go:build linux || windows //go:build linux || windows
// +build linux windows // +build linux windows
package main package platform
import ( import (
"errors" "errors"
"runtime" "runtime"
) )
func dropPrivileges(uid, gid int) error { func DropPrivileges(uid, gid int) error {
return errors.New("setuid and setgid not supported on " + runtime.GOOS) return errors.New("setuid and setgid not supported on " + runtime.GOOS)
} }

View file

@ -1,13 +1,13 @@
//go:build !windows && !linux //go:build !windows && !linux
// +build !windows,!linux // +build !windows,!linux
package main package platform
import ( import (
"syscall" "syscall"
) )
func dropPrivileges(uid, gid int) error { func DropPrivileges(uid, gid int) error {
err := syscall.Setgid(gid) err := syscall.Setgid(gid)
if err != nil { if err != nil {
return err return err

View file

@ -1,16 +1,18 @@
//go:build !windows //go:build !windows
// +build !windows // +build !windows
package main package platform
import ( import (
"log" "log"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/adnanh/webhook/internal/pidfile"
) )
func setupSignals() { func SetupSignals(signals chan os.Signal, reloadFn func(), pidFile *pidfile.PIDFile) {
log.Printf("setting up os signal watcher\n") log.Printf("setting up os signal watcher\n")
signals = make(chan os.Signal, 1) signals = make(chan os.Signal, 1)
@ -19,10 +21,10 @@ func setupSignals() {
signal.Notify(signals, syscall.SIGTERM) signal.Notify(signals, syscall.SIGTERM)
signal.Notify(signals, os.Interrupt) signal.Notify(signals, os.Interrupt)
go watchForSignals() go watchForSignals(signals, reloadFn, pidFile)
} }
func watchForSignals() { func watchForSignals(signals chan os.Signal, reloadFn func(), pidFile *pidfile.PIDFile) {
log.Println("os signal watcher ready") log.Println("os signal watcher ready")
for { for {
@ -30,11 +32,11 @@ func watchForSignals() {
switch sig { switch sig {
case syscall.SIGUSR1: case syscall.SIGUSR1:
log.Println("caught USR1 signal") log.Println("caught USR1 signal")
reloadAllHooks() reloadFn()
case syscall.SIGHUP: case syscall.SIGHUP:
log.Println("caught HUP signal") log.Println("caught HUP signal")
reloadAllHooks() reloadFn()
case os.Interrupt, syscall.SIGTERM: case os.Interrupt, syscall.SIGTERM:
log.Printf("caught %s signal; exiting\n", sig) log.Printf("caught %s signal; exiting\n", sig)

View file

@ -1,8 +1,8 @@
//go:build windows //go:build windows
// +build windows // +build windows
package main package platform
func setupSignals() { func SetupSignals() {
// NOOP: Windows doesn't have signals equivalent to the Unix world. // NOOP: Windows doesn't have signals equivalent to the Unix world.
} }

View file

@ -19,6 +19,7 @@ import (
"github.com/adnanh/webhook/internal/hook" "github.com/adnanh/webhook/internal/hook"
"github.com/adnanh/webhook/internal/middleware" "github.com/adnanh/webhook/internal/middleware"
"github.com/adnanh/webhook/internal/pidfile" "github.com/adnanh/webhook/internal/pidfile"
"github.com/adnanh/webhook/internal/platform"
chimiddleware "github.com/go-chi/chi/middleware" chimiddleware "github.com/go-chi/chi/middleware"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -144,7 +145,7 @@ func main() {
} }
if *setUID != 0 { if *setUID != 0 {
err := dropPrivileges(*setUID, *setGID) err := platform.DropPrivileges(*setUID, *setGID)
if err != nil { if err != nil {
logQueue = append(logQueue, fmt.Sprintf("error dropping privileges: %s", err)) logQueue = append(logQueue, fmt.Sprintf("error dropping privileges: %s", err))
// we'll bail out below // we'll bail out below
@ -197,7 +198,7 @@ func main() {
log.Println("version " + version + " starting") log.Println("version " + version + " starting")
// set os signal watcher // set os signal watcher
setupSignals() platform.SetupSignals(signals, reloadAllHooks, pidFile)
// load and parse hooks // load and parse hooks
for _, hooksFilePath := range hooksFiles { for _, hooksFilePath := range hooksFiles {