diff --git a/cg.go b/cg.go index a2ec441..b30dd00 100644 --- a/cg.go +++ b/cg.go @@ -13,6 +13,7 @@ import "C" import ( "errors" "runtime" + "strconv" "unsafe" ) @@ -561,3 +562,22 @@ func _err(num C.int) error { // There's a lot. We'll create them as they come return errors.New(C.GoString(C.cgroup_strerror(num))) } + +// simple helpers to get UID or GID from a given string +func stringToUID(uidStr string) UID { + intVal, err := strconv.Atoi(uidStr) + if err != nil { + return UID(0) + } + + return UID(intVal) +} + +func stringToGID(gidStr string) GID { + intVal, err := strconv.Atoi(gidStr) + if err != nil { + return GID(0) + } + + return GID(intVal) +} diff --git a/cg_test.go b/cg_test.go new file mode 100644 index 0000000..f0755ab --- /dev/null +++ b/cg_test.go @@ -0,0 +1,41 @@ +package cgroup + +import ( + "os/user" + "testing" +) + +func TestUidGid(t *testing.T) { + Init() + + var wantTaskUid, wantCtrlUid UID + var wantTaskGid, wantCtrlGid GID + + curUser, err := user.Current() + if err == nil { + wantTaskUid = stringToUID(curUser.Uid) + wantTaskGid = stringToGID(curUser.Gid) + wantCtrlUid = stringToUID(curUser.Uid) + wantCtrlGid = stringToGID(curUser.Gid) + } else { + t.Logf("cannot get the current user. fall back to 0.\n") + wantTaskUid, wantTaskGid, wantCtrlUid, wantCtrlGid = 0, 0, 0, 0 + } + + cg := NewCgroup("test_cgroup") + if err := cg.SetUIDGID(wantTaskUid, wantTaskGid, wantCtrlUid, wantCtrlGid); err != nil { + t.Fatalf("cannot set cgroup uids/gids: %v\n", err) + } + + gotTaskUid, gotTaskGid, gotCtrlUid, gotCtrlGid, err := cg.GetUIDGID() + if err != nil { + t.Fatalf("cannot get cgroup uids/gids: %v\n", err) + } + + if wantTaskUid != gotTaskUid || wantTaskGid != gotTaskGid || + wantCtrlUid != gotCtrlUid || wantCtrlGid != gotCtrlGid { + t.Fatalf("wanted (%d,%d,%d,%d), got (%d,%d,%d,%d)\n", + wantTaskUid, wantTaskGid, wantCtrlUid, wantCtrlGid, + gotTaskUid, gotTaskGid, gotCtrlUid, gotCtrlGid) + } +}