conn.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ldap
  5. import (
  6. "crypto/tls"
  7. "errors"
  8. "fmt"
  9. "log"
  10. "net"
  11. "sync"
  12. "time"
  13. "gopkg.in/asn1-ber.v1"
  14. )
  15. const (
  16. MessageQuit = 0
  17. MessageRequest = 1
  18. MessageResponse = 2
  19. MessageFinish = 3
  20. )
// messagePacket is the unit of work queued on Conn.chanMessage for the
// processMessages goroutine.
type messagePacket struct {
	// Op is one of MessageQuit, MessageRequest, MessageResponse or MessageFinish.
	Op int
	// MessageID is the LDAP message ID the operation refers to (left zero for MessageQuit).
	MessageID int64
	// Packet is the request to write (MessageRequest) or the response that was
	// read (MessageResponse).
	Packet *ber.Packet
	// Channel receives the responses for MessageID; set only for MessageRequest.
	Channel chan *ber.Packet
}
  27. type sendMessageFlags uint
  28. const (
  29. startTLS sendMessageFlags = 1 << iota
  30. )
// Conn represents an LDAP Connection.
//
// All requests are funneled through a single processMessages goroutine (via
// chanMessage), which writes them to the network and owns chanResults. A
// separate reader goroutine reads responses from conn and forwards them back
// to processMessages through chanMessage.
type Conn struct {
	// conn is the network transport; StartTLS replaces it with a *tls.Conn.
	conn net.Conn
	// isTLS is true once the connection is encrypted (DialTLS or StartTLS).
	isTLS bool
	// isClosing is set by Close and checked before queueing new messages.
	// NOTE(review): written and read without synchronization — confirm intent.
	isClosing bool
	// isStartingTLS is true while a StartTLS request is in flight; guarded by messageMutex.
	isStartingTLS bool
	// Debug controls debug output (Printf / PrintPacket).
	Debug debugging
	// chanConfirm is signalled by processMessages once it has shut down.
	chanConfirm chan bool
	// chanResults maps message IDs to their response channels; in this file it
	// is only accessed from processMessages (after NewConn creates it).
	chanResults map[int64]chan *ber.Packet
	// chanMessage is the work queue consumed by processMessages.
	chanMessage chan *messagePacket
	// chanMessageID hands out increasing message IDs from processMessages.
	chanMessageID chan int64
	// wgSender counts goroutines currently sending on chanMessage, so Close can wait for them.
	wgSender sync.WaitGroup
	// wgClose lets Close callers wait until the shutdown sequence has finished.
	wgClose sync.WaitGroup
	// once ensures the shutdown sequence in Close runs only once.
	once sync.Once
	// outstandingRequests counts requests sent but not yet finished; guarded by messageMutex.
	outstandingRequests uint
	// messageMutex guards outstandingRequests and isStartingTLS.
	messageMutex sync.Mutex
}

// Compile-time check that *Conn implements Client.
var _ Client = &Conn{}
  49. // DefaultTimeout is a package-level variable that sets the timeout value
  50. // used for the Dial and DialTLS methods.
  51. //
  52. // WARNING: since this is a package-level variable, setting this value from
  53. // multiple places will probably result in undesired behaviour.
  54. var DefaultTimeout = 60 * time.Second
  55. // Dial connects to the given address on the given network using net.Dial
  56. // and then returns a new Conn for the connection.
  57. func Dial(network, addr string) (*Conn, error) {
  58. c, err := net.DialTimeout(network, addr, DefaultTimeout)
  59. if err != nil {
  60. return nil, NewError(ErrorNetwork, err)
  61. }
  62. conn := NewConn(c, false)
  63. conn.Start()
  64. return conn, nil
  65. }
  66. // DialTLS connects to the given address on the given network using tls.Dial
  67. // and then returns a new Conn for the connection.
  68. func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
  69. dc, err := net.DialTimeout(network, addr, DefaultTimeout)
  70. if err != nil {
  71. return nil, NewError(ErrorNetwork, err)
  72. }
  73. c := tls.Client(dc, config)
  74. err = c.Handshake()
  75. if err != nil {
  76. // Handshake error, close the established connection before we return an error
  77. dc.Close()
  78. return nil, NewError(ErrorNetwork, err)
  79. }
  80. conn := NewConn(c, true)
  81. conn.Start()
  82. return conn, nil
  83. }
  84. // NewConn returns a new Conn using conn for network I/O.
  85. func NewConn(conn net.Conn, isTLS bool) *Conn {
  86. return &Conn{
  87. conn: conn,
  88. chanConfirm: make(chan bool),
  89. chanMessageID: make(chan int64),
  90. chanMessage: make(chan *messagePacket, 10),
  91. chanResults: map[int64]chan *ber.Packet{},
  92. isTLS: isTLS,
  93. }
  94. }
  95. func (l *Conn) Start() {
  96. go l.reader()
  97. go l.processMessages()
  98. l.wgClose.Add(1)
  99. }
// Close closes the connection. It may be called more than once: the shutdown
// sequence runs only once (via sync.Once), and callers then wait on wgClose
// until that sequence has finished.
//
// Shutdown order:
//  1. set isClosing so new sends are refused;
//  2. wait for in-flight sendProcessMessage calls (wgSender);
//  3. send MessageQuit and wait for processMessages to confirm on chanConfirm
//     (it closes all pending response channels on its way out);
//  4. close chanMessage and the network connection.
//
// NOTE(review): isClosing is written here and read elsewhere without
// synchronization. A sender that passed its isClosing check just before it
// was set can call wgSender.Add after wgSender.Wait has returned, and then send
// on chanMessage after it is closed. Consider guarding isClosing and
// wgSender.Add with messageMutex.
func (l *Conn) Close() {
	l.once.Do(func() {
		// Refuse new sends, then wait for senders already registered in wgSender.
		l.isClosing = true
		l.wgSender.Wait()
		l.Debug.Printf("Sending quit message and waiting for confirmation")
		l.chanMessage <- &messagePacket{Op: MessageQuit}
		<-l.chanConfirm
		// processMessages has exited, so nothing reads chanMessage any more.
		close(l.chanMessage)
		l.Debug.Printf("Closing network connection")
		if err := l.conn.Close(); err != nil {
			log.Print(err)
		}
		// Pairs with wgClose.Add(1) in Start.
		l.wgClose.Done()
	})
	// Block until the shutdown sequence above has completed.
	l.wgClose.Wait()
}
  117. // Returns the next available messageID
  118. func (l *Conn) nextMessageID() int64 {
  119. if l.chanMessageID != nil {
  120. if messageID, ok := <-l.chanMessageID; ok {
  121. return messageID
  122. }
  123. }
  124. return 0
  125. }
  126. // StartTLS sends the command to start a TLS session and then creates a new TLS Client
  127. func (l *Conn) StartTLS(config *tls.Config) error {
  128. messageID := l.nextMessageID()
  129. if l.isTLS {
  130. return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
  131. }
  132. packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
  133. packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
  134. request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
  135. request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
  136. packet.AppendChild(request)
  137. l.Debug.PrintPacket(packet)
  138. channel, err := l.sendMessageWithFlags(packet, startTLS)
  139. if err != nil {
  140. return err
  141. }
  142. if channel == nil {
  143. return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
  144. }
  145. l.Debug.Printf("%d: waiting for response", messageID)
  146. packet = <-channel
  147. l.Debug.Printf("%d: got response %p", messageID, packet)
  148. l.finishMessage(messageID)
  149. if l.Debug {
  150. if err := addLDAPDescriptions(packet); err != nil {
  151. l.Close()
  152. return err
  153. }
  154. ber.PrintPacket(packet)
  155. }
  156. if resultCode, message := getLDAPResultCode(packet); resultCode == LDAPResultSuccess {
  157. conn := tls.Client(l.conn, config)
  158. if err := conn.Handshake(); err != nil {
  159. l.Close()
  160. return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err))
  161. }
  162. l.isTLS = true
  163. l.conn = conn
  164. } else {
  165. return NewError(resultCode, fmt.Errorf("ldap: cannot StartTLS (%s)", message))
  166. }
  167. go l.reader()
  168. return nil
  169. }
  170. func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
  171. return l.sendMessageWithFlags(packet, 0)
  172. }
  173. func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (chan *ber.Packet, error) {
  174. if l.isClosing {
  175. return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
  176. }
  177. l.messageMutex.Lock()
  178. l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
  179. if l.isStartingTLS {
  180. l.messageMutex.Unlock()
  181. return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase."))
  182. }
  183. if flags&startTLS != 0 {
  184. if l.outstandingRequests != 0 {
  185. l.messageMutex.Unlock()
  186. return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
  187. } else {
  188. l.isStartingTLS = true
  189. }
  190. }
  191. l.outstandingRequests++
  192. l.messageMutex.Unlock()
  193. out := make(chan *ber.Packet)
  194. message := &messagePacket{
  195. Op: MessageRequest,
  196. MessageID: packet.Children[0].Value.(int64),
  197. Packet: packet,
  198. Channel: out,
  199. }
  200. l.sendProcessMessage(message)
  201. return out, nil
  202. }
  203. func (l *Conn) finishMessage(messageID int64) {
  204. if l.isClosing {
  205. return
  206. }
  207. l.messageMutex.Lock()
  208. l.outstandingRequests--
  209. if l.isStartingTLS {
  210. l.isStartingTLS = false
  211. }
  212. l.messageMutex.Unlock()
  213. message := &messagePacket{
  214. Op: MessageFinish,
  215. MessageID: messageID,
  216. }
  217. l.sendProcessMessage(message)
  218. }
  219. func (l *Conn) sendProcessMessage(message *messagePacket) bool {
  220. if l.isClosing {
  221. return false
  222. }
  223. l.wgSender.Add(1)
  224. l.chanMessage <- message
  225. l.wgSender.Done()
  226. return true
  227. }
  228. func (l *Conn) processMessages() {
  229. defer func() {
  230. if err := recover(); err != nil {
  231. log.Printf("ldap: recovered panic in processMessages: %v", err)
  232. }
  233. for messageID, channel := range l.chanResults {
  234. l.Debug.Printf("Closing channel for MessageID %d", messageID)
  235. close(channel)
  236. delete(l.chanResults, messageID)
  237. }
  238. close(l.chanMessageID)
  239. l.chanConfirm <- true
  240. close(l.chanConfirm)
  241. }()
  242. var messageID int64 = 1
  243. for {
  244. select {
  245. case l.chanMessageID <- messageID:
  246. messageID++
  247. case messagePacket, ok := <-l.chanMessage:
  248. if !ok {
  249. l.Debug.Printf("Shutting down - message channel is closed")
  250. return
  251. }
  252. switch messagePacket.Op {
  253. case MessageQuit:
  254. l.Debug.Printf("Shutting down - quit message received")
  255. return
  256. case MessageRequest:
  257. // Add to message list and write to network
  258. l.Debug.Printf("Sending message %d", messagePacket.MessageID)
  259. l.chanResults[messagePacket.MessageID] = messagePacket.Channel
  260. // go routine
  261. buf := messagePacket.Packet.Bytes()
  262. _, err := l.conn.Write(buf)
  263. if err != nil {
  264. l.Debug.Printf("Error Sending Message: %s", err.Error())
  265. break
  266. }
  267. case MessageResponse:
  268. l.Debug.Printf("Receiving message %d", messagePacket.MessageID)
  269. if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok {
  270. chanResult <- messagePacket.Packet
  271. } else {
  272. log.Printf("Received unexpected message %d", messagePacket.MessageID)
  273. ber.PrintPacket(messagePacket.Packet)
  274. }
  275. case MessageFinish:
  276. // Remove from message list
  277. l.Debug.Printf("Finished message %d", messagePacket.MessageID)
  278. close(l.chanResults[messagePacket.MessageID])
  279. delete(l.chanResults, messagePacket.MessageID)
  280. }
  281. }
  282. }
  283. }
  284. func (l *Conn) reader() {
  285. cleanstop := false
  286. defer func() {
  287. if err := recover(); err != nil {
  288. log.Printf("ldap: recovered panic in reader: %v", err)
  289. }
  290. if !cleanstop {
  291. l.Close()
  292. }
  293. }()
  294. for {
  295. if cleanstop {
  296. l.Debug.Printf("reader clean stopping (without closing the connection)")
  297. return
  298. }
  299. packet, err := ber.ReadPacket(l.conn)
  300. if err != nil {
  301. // A read error is expected here if we are closing the connection...
  302. if !l.isClosing {
  303. l.Debug.Printf("reader error: %s", err.Error())
  304. }
  305. return
  306. }
  307. addLDAPDescriptions(packet)
  308. if len(packet.Children) == 0 {
  309. l.Debug.Printf("Received bad ldap packet")
  310. continue
  311. }
  312. l.messageMutex.Lock()
  313. if l.isStartingTLS {
  314. cleanstop = true
  315. }
  316. l.messageMutex.Unlock()
  317. message := &messagePacket{
  318. Op: MessageResponse,
  319. MessageID: packet.Children[0].Value.(int64),
  320. Packet: packet,
  321. }
  322. if !l.sendProcessMessage(message) {
  323. return
  324. }
  325. }
  326. }