diff --git a/src/frontend/money.go b/src/frontend/money.go deleted file mode 100644 index a97e1d3..0000000 --- a/src/frontend/money.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import ( - "math" - - pb "frontend/genproto" -) - -// TODO(ahmetb): any logic below is flawed because I just realized we have no -// way of representing amounts like 17.07 because Fractional cannot store 07. -func multMoney(m pb.MoneyAmount, n uint32) pb.MoneyAmount { - out := m - for n > 1 { - out = sum(out, m) - n-- - } - return out -} - -func sum(m1, m2 pb.MoneyAmount) pb.MoneyAmount { - // TODO(ahmetb) this is copied from ./checkoutservice/money.go, find a - // better mult function. - f1, f2 := float64(m1.Fractional), float64(m2.Fractional) - lg1 := math.Max(1, math.Ceil(math.Log10(f1))) - if f1 == math.Pow(10, lg1) { - lg1++ - } - lg2 := math.Max(1, math.Ceil(math.Log10(f2))) - if f2 == math.Pow(10, lg2) { - lg2++ - } - lgMax := math.Max(lg1, lg2) - - dSum := m1.Decimal + m2.Decimal - o1 := f1 * math.Pow(10, lgMax-lg1) - o2 := f2 * math.Pow(10, lgMax-lg2) - fSum := o1 + o2 - if fSum >= math.Pow(10, lgMax) { - fSum -= math.Pow(10, lgMax) - dSum++ - } - - for int(fSum)%10 == 0 && fSum != 0 { - fSum = float64(int(fSum) / 10) - } - - return pb.MoneyAmount{ - Decimal: dSum, - Fractional: uint32(fSum)} -} diff --git a/src/frontend/money/money.go b/src/frontend/money/money.go new file mode 100644 index 0000000..bc1856a --- /dev/null +++ b/src/frontend/money/money.go @@ -0,0 +1,117 @@ +package money + +import ( + "errors" + pb "frontend/genproto" +) + +const ( + nanosMin = -999999999 + nanosMax = +999999999 + nanosMod = 1000000000 +) + +var ( + ErrInvalidValue = errors.New("one of the specified money values is invalid") + ErrMismatchingCurrency = errors.New("mismatching currency codes") +) + +// IsValid checks if specified value has a valid units/nanos signs and ranges. +func IsValid(m pb.Money) bool { + return signMatches(m) && validNanos(m.GetNanos()) +} + +func signMatches(m pb.Money) bool { + return m.GetNanos() == 0 || m.GetUnits() == 0 || (m.GetNanos() < 0) == (m.GetUnits() < 0) +} + +func validNanos(nanos int32) bool { return nanosMin <= nanos && nanos <= nanosMax } + +// IsZero returns true if the specified money value is equal to zero. +func IsZero(m pb.Money) bool { return m.GetUnits() == 0 && m.GetNanos() == 0 } + +// IsPositive returns true if the specified money value is valid and is +// positive. +func IsPositive(m pb.Money) bool { + return IsValid(m) && m.GetUnits() > 0 || (m.GetUnits() == 0 && m.GetNanos() > 0) +} + +// IsNegative returns true if the specified money value is valid and is +// negative. +func IsNegative(m pb.Money) bool { + return IsValid(m) && m.GetUnits() < 0 || (m.GetUnits() == 0 && m.GetNanos() < 0) +} + +// AreSameCurrency returns true if values l and r have a currency code and +// they are the same values. +func AreSameCurrency(l, r pb.Money) bool { + return l.GetCurrencyCode() == r.GetCurrencyCode() && l.GetCurrencyCode() != "" +} + +// AreEquals returns true if values l and r are the equal, including the +// currency. This does not check validity of the provided values. +func AreEquals(l, r pb.Money) bool { + return l.GetCurrencyCode() == r.GetCurrencyCode() && + l.GetUnits() == r.GetUnits() && l.GetNanos() == r.GetNanos() +} + +// Negate returns the same amount with the sign negated. +func Negate(m pb.Money) pb.Money { + return pb.Money{ + Units: -m.GetUnits(), + Nanos: -m.GetNanos(), + CurrencyCode: m.GetCurrencyCode()} +} + +// Must panics if the given error is not nil. This can be used with other +// functions like: "m := Must(Sum(a,b))". +func Must(v pb.Money, err error) pb.Money { + if err != nil { + panic(err) + } + return v +} + +// Sum adds two values. Returns an error if one of the values are invalid or +// currency codes are not matching (unless currency code is unspecified for +// both). +func Sum(l, r pb.Money) (pb.Money, error) { + if !IsValid(l) || !IsValid(r) { + return pb.Money{}, ErrInvalidValue + } else if l.GetCurrencyCode() != r.GetCurrencyCode() { + return pb.Money{}, ErrMismatchingCurrency + } + units := l.GetUnits() + r.GetUnits() + nanos := l.GetNanos() + r.GetNanos() + + if (units == 0 && nanos == 0) || (units > 0 && nanos >= 0) || (units < 0 && nanos <= 0) { + // same sign + units += int64(nanos / nanosMod) + nanos = nanos % nanosMod + } else { + // different sign. nanos guaranteed to not to go over the limit + if units > 0 { + units-- + nanos += nanosMod + } else { + units++ + nanos -= nanosMod + } + } + + return pb.Money{ + Units: units, + Nanos: nanos, + CurrencyCode: l.GetCurrencyCode()}, nil +} + +// MultiplySlow is a slow multiplication operation done through adding the value +// to itself n-1 times. +func MultiplySlow(m pb.MoneyAmount, n uint32) pb.MoneyAmount { + out := m + for n > 1 { + out = Sum(out, m) + n-- + } + return out +} diff --git a/src/frontend/money/money_test.go b/src/frontend/money/money_test.go new file mode 100644 index 0000000..457227a --- /dev/null +++ b/src/frontend/money/money_test.go @@ -0,0 +1,231 @@ +package money + +import ( + "fmt" + "reflect" + "testing" + + pb "frontend/genproto" +) + +func mmc(u int64, n int32, c string) pb.Money { return pb.Money{Units: u, Nanos: n, CurrencyCode: c} } +func mm(u int64, n int32) pb.Money { return mmc(u, n, "") } + +func TestIsValid(t *testing.T) { + tests := []struct { + name string + in pb.Money + want bool + }{ + {"valid -/-", mm(-981273891273, -999999999), true}, + {"invalid -/+", mm(-981273891273, +999999999), false}, + {"valid +/+", mm(981273891273, 999999999), true}, + {"invalid +/-", mm(981273891273, -999999999), false}, + {"invalid +/+overflow", mm(3, 1000000000), false}, + {"invalid +/-overflow", mm(3, -1000000000), false}, + {"invalid -/+overflow", mm(-3, 1000000000), false}, + {"invalid -/-overflow", mm(-3, -1000000000), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsValid(tt.in); got != tt.want { + t.Errorf("IsValid(%v) = %v, want %v", tt.in, got, tt.want) + } + }) + } +} + +func TestIsZero(t *testing.T) { + tests := []struct { + name string + in pb.Money + want bool + }{ + {"zero", mm(0, 0), true}, + {"not-zero (-/+)", mm(-1, +1), false}, + {"not-zero (-/-)", mm(-1, -1), false}, + {"not-zero (+/+)", mm(+1, +1), false}, + {"not-zero (+/-)", mm(+1, -1), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsZero(tt.in); got != tt.want { + t.Errorf("IsZero(%v) = %v, want %v", tt.in, got, tt.want) + } + }) + } +} + +func TestIsPositive(t *testing.T) { + tests := []struct { + name string + in pb.Money + want bool + }{ + {"zero", mm(0, 0), false}, + {"positive (+/+)", mm(+1, +1), true}, + {"invalid (-/+)", mm(-1, +1), false}, + {"negative (-/-)", mm(-1, -1), false}, + {"invalid (+/-)", mm(+1, -1), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsPositive(tt.in); got != tt.want { + t.Errorf("IsPositive(%v) = %v, want %v", tt.in, got, tt.want) + } + }) + } +} + +func TestIsNegative(t *testing.T) { + tests := []struct { + name string + in pb.Money + want bool + }{ + {"zero", mm(0, 0), false}, + {"positive (+/+)", mm(+1, +1), false}, + {"invalid (-/+)", mm(-1, +1), false}, + {"negative (-/-)", mm(-1, -1), true}, + {"invalid (+/-)", mm(+1, -1), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsNegative(tt.in); got != tt.want { + t.Errorf("IsNegative(%v) = %v, want %v", tt.in, got, tt.want) + } + }) + } +} + +func TestAreSameCurrency(t *testing.T) { + type args struct { + l pb.Money + r pb.Money + } + tests := []struct { + name string + args args + want bool + }{ + {"both empty currency", args{mmc(1, 0, ""), mmc(2, 0, "")}, false}, + {"left empty currency", args{mmc(1, 0, ""), mmc(2, 0, "USD")}, false}, + {"right empty currency", args{mmc(1, 0, "USD"), mmc(2, 0, "")}, false}, + {"mismatching", args{mmc(1, 0, "USD"), mmc(2, 0, "CAD")}, false}, + {"matching", args{mmc(1, 0, "USD"), mmc(2, 0, "USD")}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := AreSameCurrency(tt.args.l, tt.args.r); got != tt.want { + t.Errorf("AreSameCurrency([%v],[%v]) = %v, want %v", tt.args.l, tt.args.r, got, tt.want) + } + }) + } +} + +func TestAreEquals(t *testing.T) { + type args struct { + l pb.Money + r pb.Money + } + tests := []struct { + name string + args args + want bool + }{ + {"equals", args{mmc(1, 2, "USD"), mmc(1, 2, "USD")}, true}, + {"mismatching currency", args{mmc(1, 2, "USD"), mmc(1, 2, "CAD")}, false}, + {"mismatching units", args{mmc(10, 20, "USD"), mmc(1, 20, "USD")}, false}, + {"mismatching nanos", args{mmc(1, 2, "USD"), mmc(1, 20, "USD")}, false}, + {"negated", args{mmc(1, 2, "USD"), mmc(-1, -2, "USD")}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := AreEquals(tt.args.l, tt.args.r); got != tt.want { + t.Errorf("AreEquals([%v],[%v]) = %v, want %v", tt.args.l, tt.args.r, got, tt.want) + } + }) + } +} + +func TestNegate(t *testing.T) { + tests := []struct { + name string + in pb.Money + want pb.Money + }{ + {"zero", mm(0, 0), mm(0, 0)}, + {"negative", mm(-1, -200), mm(1, 200)}, + {"positive", mm(1, 200), mm(-1, -200)}, + {"carries currency code", mmc(0, 0, "XXX"), mmc(0, 0, "XXX")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Negate(tt.in); !AreEquals(got, tt.want) { + t.Errorf("Negate([%v]) = %v, want %v", tt.in, got, tt.want) + } + }) + } +} + +func TestMust_pass(t *testing.T) { + v := Must(mm(2, 3), nil) + if !AreEquals(v, mm(2, 3)) { + t.Errorf("returned the wrong value: %v", v) + } +} + +func TestMust_panic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Logf("panic captured: %v", r) + } + }() + Must(mm(2, 3), fmt.Errorf("some error")) + t.Fatal("this should not have executed due to the panic above") +} + +func TestSum(t *testing.T) { + type args struct { + l pb.Money + r pb.Money + } + tests := []struct { + name string + args args + want pb.Money + wantErr error + }{ + {"0+0=0", args{mm(0, 0), mm(0, 0)}, mm(0, 0), nil}, + {"Error: currency code on left", args{mmc(0, 0, "XXX"), mm(0, 0)}, mm(0, 0), ErrMismatchingCurrency}, + {"Error: currency code on right", args{mm(0, 0), mmc(0, 0, "YYY")}, mm(0, 0), ErrMismatchingCurrency}, + {"Error: currency code mismatch", args{mmc(0, 0, "AAA"), mmc(0, 0, "BBB")}, mm(0, 0), ErrMismatchingCurrency}, + {"Error: invalid +/-", args{mm(+1, -1), mm(0, 0)}, mm(0, 0), ErrInvalidValue}, + {"Error: invalid -/+", args{mm(0, 0), mm(-1, +2)}, mm(0, 0), ErrInvalidValue}, + {"Error: invalid nanos", args{mm(0, 1000000000), mm(1, 0)}, mm(0, 0), ErrInvalidValue}, + {"both positive (no carry)", args{mm(2, 200000000), mm(2, 200000000)}, mm(4, 400000000), nil}, + {"both positive (nanos=max)", args{mm(2, 111111111), mm(2, 888888888)}, mm(4, 999999999), nil}, + {"both positive (carry)", args{mm(2, 200000000), mm(2, 900000000)}, mm(5, 100000000), nil}, + {"both negative (no carry)", args{mm(-2, -200000000), mm(-2, -200000000)}, mm(-4, -400000000), nil}, + {"both negative (carry)", args{mm(-2, -200000000), mm(-2, -900000000)}, mm(-5, -100000000), nil}, + {"mixed (larger positive, just decimals)", args{mm(11, 0), mm(-2, 0)}, mm(9, 0), nil}, + {"mixed (larger negative, just decimals)", args{mm(-11, 0), mm(2, 0)}, mm(-9, 0), nil}, + {"mixed (larger positive, no borrow)", args{mm(11, 100000000), mm(-2, -100000000)}, mm(9, 0), nil}, + {"mixed (larger positive, with borrow)", args{mm(11, 100000000), mm(-2, -9000000 /*.09*/)}, mm(9, 91000000 /*.091*/), nil}, + {"mixed (larger negative, no borrow)", args{mm(-11, -100000000), mm(2, 100000000)}, mm(-9, 0), nil}, + {"mixed (larger negative, with borrow)", args{mm(-11, -100000000), mm(2, 9000000 /*.09*/)}, mm(-9, -91000000 /*.091*/), nil}, + {"0+negative", args{mm(0, 0), mm(-2, -100000000)}, mm(-2, -100000000), nil}, + {"negative+0", args{mm(-2, -100000000), mm(0, 0)}, mm(-2, -100000000), nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Sum(tt.args.l, tt.args.r) + if err != tt.wantErr { + t.Errorf("Sum([%v],[%v]): expected err=\"%v\" got=\"%v\"", tt.args.l, tt.args.r, tt.wantErr, err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Sum([%v],[%v]) = %v, want %v", tt.args.l, tt.args.r, got, tt.want) + } + }) + } +}