equals.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. // Copyright 2011 Aaron Jacobs. All Rights Reserved.
  2. // Author: aaronjjacobs@gmail.com (Aaron Jacobs)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. package oglematchers
  16. import (
  17. "errors"
  18. "fmt"
  19. "math"
  20. "reflect"
  21. )
  22. // Equals(x) returns a matcher that matches values v such that v and x are
  23. // equivalent. This includes the case when the comparison v == x using Go's
  24. // built-in comparison operator is legal, but for convenience the following
  25. // rules also apply:
  26. //
  27. // * Type checking is done based on underlying types rather than actual
  28. // types, so that e.g. two aliases for string can be compared:
  29. //
  30. // type stringAlias1 string
  31. // type stringAlias2 string
  32. //
  33. // a := "taco"
  34. // b := stringAlias1("taco")
  35. // c := stringAlias2("taco")
  36. //
  37. // ExpectTrue(a == b) // Legal, passes
  38. // ExpectTrue(b == c) // Illegal, doesn't compile
  39. //
  40. // ExpectThat(a, Equals(b)) // Passes
  41. // ExpectThat(b, Equals(c)) // Passes
  42. //
  43. // * Values of numeric type are treated as if they were abstract numbers, and
  44. // compared accordingly. Therefore Equals(17) will match int(17),
  45. // int16(17), uint(17), float32(17), complex64(17), and so on.
  46. //
  47. // If you want a stricter matcher that contains no such cleverness, see
  48. // IdenticalTo instead.
  49. func Equals(x interface{}) Matcher {
  50. v := reflect.ValueOf(x)
  51. // The == operator is not defined for array or struct types.
  52. if v.Kind() == reflect.Array || v.Kind() == reflect.Struct {
  53. panic(fmt.Sprintf("oglematchers.Equals: unsupported kind %v", v.Kind()))
  54. }
  55. // The == operator is not defined for non-nil slices.
  56. if v.Kind() == reflect.Slice && v.Pointer() != uintptr(0) {
  57. panic(fmt.Sprintf("oglematchers.Equals: non-nil slice"))
  58. }
  59. return &equalsMatcher{v}
  60. }
  // equalsMatcher is the matcher handed out by Equals. It keeps the expected
  // value in reflected form so that Matches can dispatch on its kind.
  type equalsMatcher struct {
  	// expectedValue is reflect.ValueOf(x) for the x given to Equals. It is the
  	// invalid (zero) reflect.Value when x was a nil literal.
  	expectedValue reflect.Value
  }
  64. ////////////////////////////////////////////////////////////////////////
  65. // Numeric types
  66. ////////////////////////////////////////////////////////////////////////
  // isSignedInteger reports whether v's kind is one of the signed integer
  // kinds: int, int8, int16, int32 or int64.
  func isSignedInteger(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  		return true
  	}
  	return false
  }
  // isUnsignedInteger reports whether v's kind is one of uint, uint8, uint16,
  // uint32 or uint64. The uintptr kind is not included.
  func isUnsignedInteger(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  		return true
  	}
  	return false
  }
  75. func isInteger(v reflect.Value) bool {
  76. return isSignedInteger(v) || isUnsignedInteger(v)
  77. }
  // isFloat reports whether v's kind is float32 or float64.
  func isFloat(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Float32, reflect.Float64:
  		return true
  	}
  	return false
  }
  // isComplex reports whether v's kind is complex64 or complex128.
  func isComplex(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Complex64, reflect.Complex128:
  		return true
  	}
  	return false
  }
  86. func checkAgainstInt64(e int64, c reflect.Value) (err error) {
  87. err = errors.New("")
  88. switch {
  89. case isSignedInteger(c):
  90. if c.Int() == e {
  91. err = nil
  92. }
  93. case isUnsignedInteger(c):
  94. u := c.Uint()
  95. if u <= math.MaxInt64 && int64(u) == e {
  96. err = nil
  97. }
  98. // Turn around the various floating point types so that the checkAgainst*
  99. // functions for them can deal with precision issues.
  100. case isFloat(c), isComplex(c):
  101. return Equals(c.Interface()).Matches(e)
  102. default:
  103. err = NewFatalError("which is not numeric")
  104. }
  105. return
  106. }
  107. func checkAgainstUint64(e uint64, c reflect.Value) (err error) {
  108. err = errors.New("")
  109. switch {
  110. case isSignedInteger(c):
  111. i := c.Int()
  112. if i >= 0 && uint64(i) == e {
  113. err = nil
  114. }
  115. case isUnsignedInteger(c):
  116. if c.Uint() == e {
  117. err = nil
  118. }
  119. // Turn around the various floating point types so that the checkAgainst*
  120. // functions for them can deal with precision issues.
  121. case isFloat(c), isComplex(c):
  122. return Equals(c.Interface()).Matches(e)
  123. default:
  124. err = NewFatalError("which is not numeric")
  125. }
  126. return
  127. }
  128. func checkAgainstFloat32(e float32, c reflect.Value) (err error) {
  129. err = errors.New("")
  130. switch {
  131. case isSignedInteger(c):
  132. if float32(c.Int()) == e {
  133. err = nil
  134. }
  135. case isUnsignedInteger(c):
  136. if float32(c.Uint()) == e {
  137. err = nil
  138. }
  139. case isFloat(c):
  140. // Compare using float32 to avoid a false sense of precision; otherwise
  141. // e.g. Equals(float32(0.1)) won't match float32(0.1).
  142. if float32(c.Float()) == e {
  143. err = nil
  144. }
  145. case isComplex(c):
  146. comp := c.Complex()
  147. rl := real(comp)
  148. im := imag(comp)
  149. // Compare using float32 to avoid a false sense of precision; otherwise
  150. // e.g. Equals(float32(0.1)) won't match (0.1 + 0i).
  151. if im == 0 && float32(rl) == e {
  152. err = nil
  153. }
  154. default:
  155. err = NewFatalError("which is not numeric")
  156. }
  157. return
  158. }
  159. func checkAgainstFloat64(e float64, c reflect.Value) (err error) {
  160. err = errors.New("")
  161. ck := c.Kind()
  162. switch {
  163. case isSignedInteger(c):
  164. if float64(c.Int()) == e {
  165. err = nil
  166. }
  167. case isUnsignedInteger(c):
  168. if float64(c.Uint()) == e {
  169. err = nil
  170. }
  171. // If the actual value is lower precision, turn the comparison around so we
  172. // apply the low-precision rules. Otherwise, e.g. Equals(0.1) may not match
  173. // float32(0.1).
  174. case ck == reflect.Float32 || ck == reflect.Complex64:
  175. return Equals(c.Interface()).Matches(e)
  176. // Otherwise, compare with double precision.
  177. case isFloat(c):
  178. if c.Float() == e {
  179. err = nil
  180. }
  181. case isComplex(c):
  182. comp := c.Complex()
  183. rl := real(comp)
  184. im := imag(comp)
  185. if im == 0 && rl == e {
  186. err = nil
  187. }
  188. default:
  189. err = NewFatalError("which is not numeric")
  190. }
  191. return
  192. }
  193. func checkAgainstComplex64(e complex64, c reflect.Value) (err error) {
  194. err = errors.New("")
  195. realPart := real(e)
  196. imaginaryPart := imag(e)
  197. switch {
  198. case isInteger(c) || isFloat(c):
  199. // If we have no imaginary part, then we should just compare against the
  200. // real part. Otherwise, we can't be equal.
  201. if imaginaryPart != 0 {
  202. return
  203. }
  204. return checkAgainstFloat32(realPart, c)
  205. case isComplex(c):
  206. // Compare using complex64 to avoid a false sense of precision; otherwise
  207. // e.g. Equals(0.1 + 0i) won't match float32(0.1).
  208. if complex64(c.Complex()) == e {
  209. err = nil
  210. }
  211. default:
  212. err = NewFatalError("which is not numeric")
  213. }
  214. return
  215. }
  216. func checkAgainstComplex128(e complex128, c reflect.Value) (err error) {
  217. err = errors.New("")
  218. realPart := real(e)
  219. imaginaryPart := imag(e)
  220. switch {
  221. case isInteger(c) || isFloat(c):
  222. // If we have no imaginary part, then we should just compare against the
  223. // real part. Otherwise, we can't be equal.
  224. if imaginaryPart != 0 {
  225. return
  226. }
  227. return checkAgainstFloat64(realPart, c)
  228. case isComplex(c):
  229. if c.Complex() == e {
  230. err = nil
  231. }
  232. default:
  233. err = NewFatalError("which is not numeric")
  234. }
  235. return
  236. }
  237. ////////////////////////////////////////////////////////////////////////
  238. // Other types
  239. ////////////////////////////////////////////////////////////////////////
  240. func checkAgainstBool(e bool, c reflect.Value) (err error) {
  241. if c.Kind() != reflect.Bool {
  242. err = NewFatalError("which is not a bool")
  243. return
  244. }
  245. err = errors.New("")
  246. if c.Bool() == e {
  247. err = nil
  248. }
  249. return
  250. }
  251. func checkAgainstUintptr(e uintptr, c reflect.Value) (err error) {
  252. if c.Kind() != reflect.Uintptr {
  253. err = NewFatalError("which is not a uintptr")
  254. return
  255. }
  256. err = errors.New("")
  257. if uintptr(c.Uint()) == e {
  258. err = nil
  259. }
  260. return
  261. }
  262. func checkAgainstChan(e reflect.Value, c reflect.Value) (err error) {
  263. // Create a description of e's type, e.g. "chan int".
  264. typeStr := fmt.Sprintf("%s %s", e.Type().ChanDir(), e.Type().Elem())
  265. // Make sure c is a chan of the correct type.
  266. if c.Kind() != reflect.Chan ||
  267. c.Type().ChanDir() != e.Type().ChanDir() ||
  268. c.Type().Elem() != e.Type().Elem() {
  269. err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))
  270. return
  271. }
  272. err = errors.New("")
  273. if c.Pointer() == e.Pointer() {
  274. err = nil
  275. }
  276. return
  277. }
  278. func checkAgainstFunc(e reflect.Value, c reflect.Value) (err error) {
  279. // Make sure c is a function.
  280. if c.Kind() != reflect.Func {
  281. err = NewFatalError("which is not a function")
  282. return
  283. }
  284. err = errors.New("")
  285. if c.Pointer() == e.Pointer() {
  286. err = nil
  287. }
  288. return
  289. }
  290. func checkAgainstMap(e reflect.Value, c reflect.Value) (err error) {
  291. // Make sure c is a map.
  292. if c.Kind() != reflect.Map {
  293. err = NewFatalError("which is not a map")
  294. return
  295. }
  296. err = errors.New("")
  297. if c.Pointer() == e.Pointer() {
  298. err = nil
  299. }
  300. return
  301. }
  302. func checkAgainstPtr(e reflect.Value, c reflect.Value) (err error) {
  303. // Create a description of e's type, e.g. "*int".
  304. typeStr := fmt.Sprintf("*%v", e.Type().Elem())
  305. // Make sure c is a pointer of the correct type.
  306. if c.Kind() != reflect.Ptr ||
  307. c.Type().Elem() != e.Type().Elem() {
  308. err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))
  309. return
  310. }
  311. err = errors.New("")
  312. if c.Pointer() == e.Pointer() {
  313. err = nil
  314. }
  315. return
  316. }
  317. func checkAgainstSlice(e reflect.Value, c reflect.Value) (err error) {
  318. // Create a description of e's type, e.g. "[]int".
  319. typeStr := fmt.Sprintf("[]%v", e.Type().Elem())
  320. // Make sure c is a slice of the correct type.
  321. if c.Kind() != reflect.Slice ||
  322. c.Type().Elem() != e.Type().Elem() {
  323. err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))
  324. return
  325. }
  326. err = errors.New("")
  327. if c.Pointer() == e.Pointer() {
  328. err = nil
  329. }
  330. return
  331. }
  332. func checkAgainstString(e reflect.Value, c reflect.Value) (err error) {
  333. // Make sure c is a string.
  334. if c.Kind() != reflect.String {
  335. err = NewFatalError("which is not a string")
  336. return
  337. }
  338. err = errors.New("")
  339. if c.String() == e.String() {
  340. err = nil
  341. }
  342. return
  343. }
  344. func checkAgainstUnsafePointer(e reflect.Value, c reflect.Value) (err error) {
  345. // Make sure c is a pointer.
  346. if c.Kind() != reflect.UnsafePointer {
  347. err = NewFatalError("which is not a unsafe.Pointer")
  348. return
  349. }
  350. err = errors.New("")
  351. if c.Pointer() == e.Pointer() {
  352. err = nil
  353. }
  354. return
  355. }
  356. func checkForNil(c reflect.Value) (err error) {
  357. err = errors.New("")
  358. // Make sure it is legal to call IsNil.
  359. switch c.Kind() {
  360. case reflect.Invalid:
  361. case reflect.Chan:
  362. case reflect.Func:
  363. case reflect.Interface:
  364. case reflect.Map:
  365. case reflect.Ptr:
  366. case reflect.Slice:
  367. default:
  368. err = NewFatalError("which cannot be compared to nil")
  369. return
  370. }
  371. // Ask whether the value is nil. Handle a nil literal (kind Invalid)
  372. // specially, since it's not legal to call IsNil there.
  373. if c.Kind() == reflect.Invalid || c.IsNil() {
  374. err = nil
  375. }
  376. return
  377. }
  378. ////////////////////////////////////////////////////////////////////////
  379. // Public implementation
  380. ////////////////////////////////////////////////////////////////////////
  381. func (m *equalsMatcher) Matches(candidate interface{}) error {
  382. e := m.expectedValue
  383. c := reflect.ValueOf(candidate)
  384. ek := e.Kind()
  385. switch {
  386. case ek == reflect.Bool:
  387. return checkAgainstBool(e.Bool(), c)
  388. case isSignedInteger(e):
  389. return checkAgainstInt64(e.Int(), c)
  390. case isUnsignedInteger(e):
  391. return checkAgainstUint64(e.Uint(), c)
  392. case ek == reflect.Uintptr:
  393. return checkAgainstUintptr(uintptr(e.Uint()), c)
  394. case ek == reflect.Float32:
  395. return checkAgainstFloat32(float32(e.Float()), c)
  396. case ek == reflect.Float64:
  397. return checkAgainstFloat64(e.Float(), c)
  398. case ek == reflect.Complex64:
  399. return checkAgainstComplex64(complex64(e.Complex()), c)
  400. case ek == reflect.Complex128:
  401. return checkAgainstComplex128(complex128(e.Complex()), c)
  402. case ek == reflect.Chan:
  403. return checkAgainstChan(e, c)
  404. case ek == reflect.Func:
  405. return checkAgainstFunc(e, c)
  406. case ek == reflect.Map:
  407. return checkAgainstMap(e, c)
  408. case ek == reflect.Ptr:
  409. return checkAgainstPtr(e, c)
  410. case ek == reflect.Slice:
  411. return checkAgainstSlice(e, c)
  412. case ek == reflect.String:
  413. return checkAgainstString(e, c)
  414. case ek == reflect.UnsafePointer:
  415. return checkAgainstUnsafePointer(e, c)
  416. case ek == reflect.Invalid:
  417. return checkForNil(c)
  418. }
  419. panic(fmt.Sprintf("equalsMatcher.Matches: unexpected kind: %v", ek))
  420. }
  421. func (m *equalsMatcher) Description() string {
  422. // Special case: handle nil.
  423. if !m.expectedValue.IsValid() {
  424. return "is nil"
  425. }
  426. return fmt.Sprintf("%v", m.expectedValue.Interface())
  427. }