inject.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
// Package inject provides a reflect based injector. A large application built
// with dependency injection in mind will typically involve the boring work of
// setting up the object graph. This library attempts to take care of this
// boring work by creating and connecting the various objects. Its use involves
// you seeding the object graph with some (possibly incomplete) objects, where
// the underlying types have been tagged for injection. Given this, the
// library will populate the objects, creating new ones as necessary. It uses
// singletons by default, and supports optional private instances as well as
// named instances.
//
// It works using Go's reflection package and is inherently limited in what it
// can do, as opposed to a code-gen system, with respect to private fields.
//
// The usage pattern for the library involves struct tags. It requires the tag
// format used by the various standard libraries, like json, xml etc. It
// involves tags in one of the four forms below:
//
//	`inject:""`
//	`inject:"private"`
//	`inject:"inline"`
//	`inject:"dev logger"`
//
// The first, no-value syntax is for the common case of a singleton dependency
// of the associated type. The second triggers creation of a private instance
// for the associated type. The third is only valid on struct-valued fields and
// asks the library to traverse into that struct and populate its tagged
// fields in place. Finally, the last form is asking for a named dependency
// called "dev logger".
  26. package inject
  27. import (
  28. "bytes"
  29. "fmt"
  30. "math/rand"
  31. "reflect"
  32. "github.com/facebookgo/structtag"
  33. )
// Logger allows for simple logging as inject traverses and populates the
// object graph. A Graph only emits debug output when its Logger field is
// non-nil.
type Logger interface {
	// Debugf receives a printf-style format string together with its
	// arguments, one call per event (object provided, field assigned, ...).
	Debugf(format string, v ...interface{})
}
  39. // Populate is a short-hand for populating a graph with the given incomplete
  40. // object values.
  41. func Populate(values ...interface{}) error {
  42. var g Graph
  43. for _, v := range values {
  44. if err := g.Provide(&Object{Value: v}); err != nil {
  45. return err
  46. }
  47. }
  48. return g.Populate()
  49. }
// An Object in the Graph. The exported fields are set by the caller before
// the Object is handed to Graph.Provide (except Fields, which must be left
// nil); the unexported fields are bookkeeping maintained by the library.
type Object struct {
	// Value is the instance itself. Unnamed objects must hold a pointer to a
	// struct; Provide rejects anything else.
	Value interface{}
	Name  string // Optional
	// Complete marks the Value as fully set up; Populate skips such objects.
	Complete bool // If true, the Value will be considered complete
	// Fields must be nil when the Object is provided; it is filled in while
	// the graph is populated.
	Fields map[string]*Object // Populated with the field names that were injected and their corresponding *Object.

	// reflectType and reflectValue cache reflect.TypeOf(Value) and
	// reflect.ValueOf(Value); both are assigned by Provide.
	reflectType  reflect.Type
	reflectValue reflect.Value
	private      bool // If true, the Value will not be used and will only be populated
	created      bool // If true, the Object was created by us
	embedded     bool // If true, the Object is an embedded struct provided internally
}
  62. // String representation suitable for human consumption.
  63. func (o *Object) String() string {
  64. var buf bytes.Buffer
  65. fmt.Fprint(&buf, o.reflectType)
  66. if o.Name != "" {
  67. fmt.Fprintf(&buf, " named %s", o.Name)
  68. }
  69. return buf.String()
  70. }
  71. func (o *Object) addDep(field string, dep *Object) {
  72. if o.Fields == nil {
  73. o.Fields = make(map[string]*Object)
  74. }
  75. o.Fields[field] = dep
  76. }
// The Graph of Objects. The zero value is ready to use: the internal maps are
// allocated lazily by Provide.
type Graph struct {
	Logger Logger // Optional, will trigger debug logging.

	// unnamed holds every object provided without a Name, in the order it
	// was provided (including objects created while populating).
	unnamed []*Object
	// unnamedType tracks the types of the non-private unnamed objects so a
	// second unnamed instance of the same type can be rejected.
	unnamedType map[reflect.Type]bool
	// named indexes the objects that were provided with a Name.
	named map[string]*Object
}
  84. // Provide objects to the Graph. The Object documentation describes
  85. // the impact of various fields.
  86. func (g *Graph) Provide(objects ...*Object) error {
  87. for _, o := range objects {
  88. o.reflectType = reflect.TypeOf(o.Value)
  89. o.reflectValue = reflect.ValueOf(o.Value)
  90. if o.Fields != nil {
  91. return fmt.Errorf(
  92. "fields were specified on object %s when it was provided",
  93. o,
  94. )
  95. }
  96. if o.Name == "" {
  97. if !isStructPtr(o.reflectType) {
  98. return fmt.Errorf(
  99. "expected unnamed object value to be a pointer to a struct but got type %s "+
  100. "with value %v",
  101. o.reflectType,
  102. o.Value,
  103. )
  104. }
  105. if !o.private {
  106. if g.unnamedType == nil {
  107. g.unnamedType = make(map[reflect.Type]bool)
  108. }
  109. if g.unnamedType[o.reflectType] {
  110. return fmt.Errorf(
  111. "provided two unnamed instances of type *%s.%s",
  112. o.reflectType.Elem().PkgPath(), o.reflectType.Elem().Name(),
  113. )
  114. }
  115. g.unnamedType[o.reflectType] = true
  116. }
  117. g.unnamed = append(g.unnamed, o)
  118. } else {
  119. if g.named == nil {
  120. g.named = make(map[string]*Object)
  121. }
  122. if g.named[o.Name] != nil {
  123. return fmt.Errorf("provided two instances named %s", o.Name)
  124. }
  125. g.named[o.Name] = o
  126. }
  127. if g.Logger != nil {
  128. if o.created {
  129. g.Logger.Debugf("created %s", o)
  130. } else if o.embedded {
  131. g.Logger.Debugf("provided embedded %s", o)
  132. } else {
  133. g.Logger.Debugf("provided %s", o)
  134. }
  135. }
  136. }
  137. return nil
  138. }
  139. // Populate the incomplete Objects.
  140. func (g *Graph) Populate() error {
  141. for _, o := range g.named {
  142. if o.Complete {
  143. continue
  144. }
  145. if err := g.populateExplicit(o); err != nil {
  146. return err
  147. }
  148. }
  149. // We append and modify our slice as we go along, so we don't use a standard
  150. // range loop, and do a single pass thru each object in our graph.
  151. i := 0
  152. for {
  153. if i == len(g.unnamed) {
  154. break
  155. }
  156. o := g.unnamed[i]
  157. i++
  158. if o.Complete {
  159. continue
  160. }
  161. if err := g.populateExplicit(o); err != nil {
  162. return err
  163. }
  164. }
  165. // A Second pass handles injecting Interface values to ensure we have created
  166. // all concrete types first.
  167. for _, o := range g.unnamed {
  168. if o.Complete {
  169. continue
  170. }
  171. if err := g.populateUnnamedInterface(o); err != nil {
  172. return err
  173. }
  174. }
  175. for _, o := range g.named {
  176. if o.Complete {
  177. continue
  178. }
  179. if err := g.populateUnnamedInterface(o); err != nil {
  180. return err
  181. }
  182. }
  183. return nil
  184. }
  185. func (g *Graph) populateExplicit(o *Object) error {
  186. // Ignore named value types.
  187. if o.Name != "" && !isStructPtr(o.reflectType) {
  188. return nil
  189. }
  190. StructLoop:
  191. for i := 0; i < o.reflectValue.Elem().NumField(); i++ {
  192. field := o.reflectValue.Elem().Field(i)
  193. fieldType := field.Type()
  194. fieldTag := o.reflectType.Elem().Field(i).Tag
  195. fieldName := o.reflectType.Elem().Field(i).Name
  196. tag, err := parseTag(string(fieldTag))
  197. if err != nil {
  198. return fmt.Errorf(
  199. "unexpected tag format `%s` for field %s in type %s",
  200. string(fieldTag),
  201. o.reflectType.Elem().Field(i).Name,
  202. o.reflectType,
  203. )
  204. }
  205. // Skip fields without a tag.
  206. if tag == nil {
  207. continue
  208. }
  209. // Cannot be used with unexported fields.
  210. if !field.CanSet() {
  211. return fmt.Errorf(
  212. "inject requested on unexported field %s in type %s",
  213. o.reflectType.Elem().Field(i).Name,
  214. o.reflectType,
  215. )
  216. }
  217. // Inline tag on anything besides a struct is considered invalid.
  218. if tag.Inline && fieldType.Kind() != reflect.Struct {
  219. return fmt.Errorf(
  220. "inline requested on non inlined field %s in type %s",
  221. o.reflectType.Elem().Field(i).Name,
  222. o.reflectType,
  223. )
  224. }
  225. // Don't overwrite existing values.
  226. if !isNilOrZero(field, fieldType) {
  227. continue
  228. }
  229. // Named injects must have been explicitly provided.
  230. if tag.Name != "" {
  231. existing := g.named[tag.Name]
  232. if existing == nil {
  233. return fmt.Errorf(
  234. "did not find object named %s required by field %s in type %s",
  235. tag.Name,
  236. o.reflectType.Elem().Field(i).Name,
  237. o.reflectType,
  238. )
  239. }
  240. if !existing.reflectType.AssignableTo(fieldType) {
  241. return fmt.Errorf(
  242. "object named %s of type %s is not assignable to field %s (%s) in type %s",
  243. tag.Name,
  244. fieldType,
  245. o.reflectType.Elem().Field(i).Name,
  246. existing.reflectType,
  247. o.reflectType,
  248. )
  249. }
  250. field.Set(reflect.ValueOf(existing.Value))
  251. if g.Logger != nil {
  252. g.Logger.Debugf(
  253. "assigned %s to field %s in %s",
  254. existing,
  255. o.reflectType.Elem().Field(i).Name,
  256. o,
  257. )
  258. }
  259. o.addDep(fieldName, existing)
  260. continue StructLoop
  261. }
  262. // Inline struct values indicate we want to traverse into it, but not
  263. // inject itself. We require an explicit "inline" tag for this to work.
  264. if fieldType.Kind() == reflect.Struct {
  265. if tag.Private {
  266. return fmt.Errorf(
  267. "cannot use private inject on inline struct on field %s in type %s",
  268. o.reflectType.Elem().Field(i).Name,
  269. o.reflectType,
  270. )
  271. }
  272. if !tag.Inline {
  273. return fmt.Errorf(
  274. "inline struct on field %s in type %s requires an explicit \"inline\" tag",
  275. o.reflectType.Elem().Field(i).Name,
  276. o.reflectType,
  277. )
  278. }
  279. err := g.Provide(&Object{
  280. Value: field.Addr().Interface(),
  281. private: true,
  282. embedded: o.reflectType.Elem().Field(i).Anonymous,
  283. })
  284. if err != nil {
  285. return err
  286. }
  287. continue
  288. }
  289. // Interface injection is handled in a second pass.
  290. if fieldType.Kind() == reflect.Interface {
  291. continue
  292. }
  293. // Maps are created and required to be private.
  294. if fieldType.Kind() == reflect.Map {
  295. if !tag.Private {
  296. return fmt.Errorf(
  297. "inject on map field %s in type %s must be named or private",
  298. o.reflectType.Elem().Field(i).Name,
  299. o.reflectType,
  300. )
  301. }
  302. field.Set(reflect.MakeMap(fieldType))
  303. if g.Logger != nil {
  304. g.Logger.Debugf(
  305. "made map for field %s in %s",
  306. o.reflectType.Elem().Field(i).Name,
  307. o,
  308. )
  309. }
  310. continue
  311. }
  312. // Can only inject Pointers from here on.
  313. if !isStructPtr(fieldType) {
  314. return fmt.Errorf(
  315. "found inject tag on unsupported field %s in type %s",
  316. o.reflectType.Elem().Field(i).Name,
  317. o.reflectType,
  318. )
  319. }
  320. // Unless it's a private inject, we'll look for an existing instance of the
  321. // same type.
  322. if !tag.Private {
  323. for _, existing := range g.unnamed {
  324. if existing.private {
  325. continue
  326. }
  327. if existing.reflectType.AssignableTo(fieldType) {
  328. field.Set(reflect.ValueOf(existing.Value))
  329. if g.Logger != nil {
  330. g.Logger.Debugf(
  331. "assigned existing %s to field %s in %s",
  332. existing,
  333. o.reflectType.Elem().Field(i).Name,
  334. o,
  335. )
  336. }
  337. o.addDep(fieldName, existing)
  338. continue StructLoop
  339. }
  340. }
  341. }
  342. newValue := reflect.New(fieldType.Elem())
  343. newObject := &Object{
  344. Value: newValue.Interface(),
  345. private: tag.Private,
  346. created: true,
  347. }
  348. // Add the newly ceated object to the known set of objects.
  349. err = g.Provide(newObject)
  350. if err != nil {
  351. return err
  352. }
  353. // Finally assign the newly created object to our field.
  354. field.Set(newValue)
  355. if g.Logger != nil {
  356. g.Logger.Debugf(
  357. "assigned newly created %s to field %s in %s",
  358. newObject,
  359. o.reflectType.Elem().Field(i).Name,
  360. o,
  361. )
  362. }
  363. o.addDep(fieldName, newObject)
  364. }
  365. return nil
  366. }
  367. func (g *Graph) populateUnnamedInterface(o *Object) error {
  368. // Ignore named value types.
  369. if o.Name != "" && !isStructPtr(o.reflectType) {
  370. return nil
  371. }
  372. for i := 0; i < o.reflectValue.Elem().NumField(); i++ {
  373. field := o.reflectValue.Elem().Field(i)
  374. fieldType := field.Type()
  375. fieldTag := o.reflectType.Elem().Field(i).Tag
  376. fieldName := o.reflectType.Elem().Field(i).Name
  377. tag, err := parseTag(string(fieldTag))
  378. if err != nil {
  379. return fmt.Errorf(
  380. "unexpected tag format `%s` for field %s in type %s",
  381. string(fieldTag),
  382. o.reflectType.Elem().Field(i).Name,
  383. o.reflectType,
  384. )
  385. }
  386. // Skip fields without a tag.
  387. if tag == nil {
  388. continue
  389. }
  390. // We only handle interface injection here. Other cases including errors
  391. // are handled in the first pass when we inject pointers.
  392. if fieldType.Kind() != reflect.Interface {
  393. continue
  394. }
  395. // Interface injection can't be private because we can't instantiate new
  396. // instances of an interface.
  397. if tag.Private {
  398. return fmt.Errorf(
  399. "found private inject tag on interface field %s in type %s",
  400. o.reflectType.Elem().Field(i).Name,
  401. o.reflectType,
  402. )
  403. }
  404. // Don't overwrite existing values.
  405. if !isNilOrZero(field, fieldType) {
  406. continue
  407. }
  408. // Named injects must have already been handled in populateExplicit.
  409. if tag.Name != "" {
  410. panic(fmt.Sprintf("unhandled named instance with name %s", tag.Name))
  411. }
  412. // Find one, and only one assignable value for the field.
  413. var found *Object
  414. for _, existing := range g.unnamed {
  415. if existing.private {
  416. continue
  417. }
  418. if existing.reflectType.AssignableTo(fieldType) {
  419. if found != nil {
  420. return fmt.Errorf(
  421. "found two assignable values for field %s in type %s. one type "+
  422. "%s with value %v and another type %s with value %v",
  423. o.reflectType.Elem().Field(i).Name,
  424. o.reflectType,
  425. found.reflectType,
  426. found.Value,
  427. existing.reflectType,
  428. existing.reflectValue,
  429. )
  430. }
  431. found = existing
  432. field.Set(reflect.ValueOf(existing.Value))
  433. if g.Logger != nil {
  434. g.Logger.Debugf(
  435. "assigned existing %s to interface field %s in %s",
  436. existing,
  437. o.reflectType.Elem().Field(i).Name,
  438. o,
  439. )
  440. }
  441. o.addDep(fieldName, existing)
  442. }
  443. }
  444. // If we didn't find an assignable value, we're missing something.
  445. if found == nil {
  446. return fmt.Errorf(
  447. "found no assignable value for field %s in type %s",
  448. o.reflectType.Elem().Field(i).Name,
  449. o.reflectType,
  450. )
  451. }
  452. }
  453. return nil
  454. }
  455. // Objects returns all known objects, named as well as unnamed. The returned
  456. // elements are not in a stable order.
  457. func (g *Graph) Objects() []*Object {
  458. objects := make([]*Object, 0, len(g.unnamed)+len(g.named))
  459. for _, o := range g.unnamed {
  460. if !o.embedded {
  461. objects = append(objects, o)
  462. }
  463. }
  464. for _, o := range g.named {
  465. if !o.embedded {
  466. objects = append(objects, o)
  467. }
  468. }
  469. // randomize to prevent callers from relying on ordering
  470. for i := 0; i < len(objects); i++ {
  471. j := rand.Intn(i + 1)
  472. objects[i], objects[j] = objects[j], objects[i]
  473. }
  474. return objects
  475. }
// Shared tag values returned by parseTag for the three tag forms that carry
// no name. They are handed out by pointer, so they must be treated as
// read-only.
var (
	injectOnly    = &tag{}              // `inject:""`
	injectPrivate = &tag{Private: true} // `inject:"private"`
	injectInline  = &tag{Inline: true}  // `inject:"inline"`
)
// tag is the parsed form of an `inject:"..."` struct tag value.
type tag struct {
	Name    string // name of the requested object; empty for the unnamed forms
	Inline  bool   // set for the value "inline"
	Private bool   // set for the value "private"
}
  486. func parseTag(t string) (*tag, error) {
  487. found, value, err := structtag.Extract("inject", t)
  488. if err != nil {
  489. return nil, err
  490. }
  491. if !found {
  492. return nil, nil
  493. }
  494. if value == "" {
  495. return injectOnly, nil
  496. }
  497. if value == "inline" {
  498. return injectInline, nil
  499. }
  500. if value == "private" {
  501. return injectPrivate, nil
  502. }
  503. return &tag{Name: value}, nil
  504. }
  505. func isStructPtr(t reflect.Type) bool {
  506. return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
  507. }
  508. func isNilOrZero(v reflect.Value, t reflect.Type) bool {
  509. switch v.Kind() {
  510. default:
  511. return reflect.DeepEqual(v.Interface(), reflect.Zero(t).Interface())
  512. case reflect.Interface, reflect.Ptr:
  513. return v.IsNil()
  514. }
  515. }