conn.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ldap
  5. import (
  6. "crypto/tls"
  7. "errors"
  8. "fmt"
  9. "log"
  10. "net"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. "gopkg.in/asn1-ber.v1"
  15. )
  16. const (
  17. // MessageQuit causes the processMessages loop to exit
  18. MessageQuit = 0
  19. // MessageRequest sends a request to the server
  20. MessageRequest = 1
  21. // MessageResponse receives a response from the server
  22. MessageResponse = 2
  23. // MessageFinish indicates the client considers a particular message ID to be finished
  24. MessageFinish = 3
  25. // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
  26. MessageTimeout = 4
  27. )
// PacketResponse contains the packet or error encountered reading a response.
// processMessages delivers values of this type on a request's responses
// channel. Field order matters: processMessages builds some of these values
// with positional composite literals ({Packet, Error}).
type PacketResponse struct {
	// Packet is the packet read from the server
	Packet *ber.Packet
	// Error is an error encountered while reading
	Error error
}
  35. // ReadPacket returns the packet or an error
  36. func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
  37. if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
  38. return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
  39. }
  40. return pr.Packet, pr.Error
  41. }
// messageContext tracks one in-flight request created by sendMessageWithFlags.
type messageContext struct {
	// id is the LDAP message ID of the request.
	id int64
	// done is closed when the caller no longer wants responses; sendResponse
	// stops blocking once it is closed.
	// close(done) should only be called from finishMessage()
	done chan struct{}
	// responses carries the server's replies (or errors) for this request.
	// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
	responses chan *PacketResponse
}
  49. // sendResponse should only be called within the processMessages() loop which
  50. // is also responsible for closing the responses channel.
  51. func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
  52. select {
  53. case msgCtx.responses <- packet:
  54. // Successfully sent packet to message handler.
  55. case <-msgCtx.done:
  56. // The request handler is done and will not receive more
  57. // packets.
  58. }
  59. }
// messagePacket is a command sent to the processMessages loop over
// Conn.chanMessage. Op is one of the Message* constants; which of the other
// fields are set depends on Op (e.g. MessageQuit carries only Op, and
// MessageFinish carries Op and MessageID).
type messagePacket struct {
	Op        int
	MessageID int64
	// Packet is the request to write (MessageRequest) or the packet read
	// from the server (MessageResponse).
	Packet *ber.Packet
	// Context is the per-request state; set for MessageRequest.
	Context *messageContext
}
  66. type sendMessageFlags uint
  67. const (
  68. startTLS sendMessageFlags = 1 << iota
  69. )
  70. // Conn represents an LDAP Connection
  71. type Conn struct {
  72. conn net.Conn
  73. isTLS bool
  74. closing uint32
  75. closeErr atomicValue
  76. isStartingTLS bool
  77. Debug debugging
  78. chanConfirm chan struct{}
  79. messageContexts map[int64]*messageContext
  80. chanMessage chan *messagePacket
  81. chanMessageID chan int64
  82. wgClose sync.WaitGroup
  83. outstandingRequests uint
  84. messageMutex sync.Mutex
  85. requestTimeout int64
  86. }
  87. var _ Client = &Conn{}
  88. // DefaultTimeout is a package-level variable that sets the timeout value
  89. // used for the Dial and DialTLS methods.
  90. //
  91. // WARNING: since this is a package-level variable, setting this value from
  92. // multiple places will probably result in undesired behaviour.
  93. var DefaultTimeout = 60 * time.Second
  94. // Dial connects to the given address on the given network using net.Dial
  95. // and then returns a new Conn for the connection.
  96. func Dial(network, addr string) (*Conn, error) {
  97. c, err := net.DialTimeout(network, addr, DefaultTimeout)
  98. if err != nil {
  99. return nil, NewError(ErrorNetwork, err)
  100. }
  101. conn := NewConn(c, false)
  102. conn.Start()
  103. return conn, nil
  104. }
  105. // DialTLS connects to the given address on the given network using tls.Dial
  106. // and then returns a new Conn for the connection.
  107. func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
  108. dc, err := net.DialTimeout(network, addr, DefaultTimeout)
  109. if err != nil {
  110. return nil, NewError(ErrorNetwork, err)
  111. }
  112. c := tls.Client(dc, config)
  113. err = c.Handshake()
  114. if err != nil {
  115. // Handshake error, close the established connection before we return an error
  116. dc.Close()
  117. return nil, NewError(ErrorNetwork, err)
  118. }
  119. conn := NewConn(c, true)
  120. conn.Start()
  121. return conn, nil
  122. }
  123. // NewConn returns a new Conn using conn for network I/O.
  124. func NewConn(conn net.Conn, isTLS bool) *Conn {
  125. return &Conn{
  126. conn: conn,
  127. chanConfirm: make(chan struct{}),
  128. chanMessageID: make(chan int64),
  129. chanMessage: make(chan *messagePacket, 10),
  130. messageContexts: map[int64]*messageContext{},
  131. requestTimeout: 0,
  132. isTLS: isTLS,
  133. }
  134. }
  135. // Start initializes goroutines to read responses and process messages
  136. func (l *Conn) Start() {
  137. go l.reader()
  138. go l.processMessages()
  139. l.wgClose.Add(1)
  140. }
  141. // isClosing returns whether or not we're currently closing.
  142. func (l *Conn) isClosing() bool {
  143. return atomic.LoadUint32(&l.closing) == 1
  144. }
  145. // setClosing sets the closing value to true
  146. func (l *Conn) setClosing() bool {
  147. return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
  148. }
  149. // Close closes the connection.
  150. func (l *Conn) Close() {
  151. l.messageMutex.Lock()
  152. defer l.messageMutex.Unlock()
  153. if l.setClosing() {
  154. l.Debug.Printf("Sending quit message and waiting for confirmation")
  155. l.chanMessage <- &messagePacket{Op: MessageQuit}
  156. <-l.chanConfirm
  157. close(l.chanMessage)
  158. l.Debug.Printf("Closing network connection")
  159. if err := l.conn.Close(); err != nil {
  160. log.Println(err)
  161. }
  162. l.wgClose.Done()
  163. }
  164. l.wgClose.Wait()
  165. }
  166. // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
  167. func (l *Conn) SetTimeout(timeout time.Duration) {
  168. if timeout > 0 {
  169. atomic.StoreInt64(&l.requestTimeout, int64(timeout))
  170. }
  171. }
  172. // Returns the next available messageID
  173. func (l *Conn) nextMessageID() int64 {
  174. if messageID, ok := <-l.chanMessageID; ok {
  175. return messageID
  176. }
  177. return 0
  178. }
  179. // StartTLS sends the command to start a TLS session and then creates a new TLS Client
  180. func (l *Conn) StartTLS(config *tls.Config) error {
  181. if l.isTLS {
  182. return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
  183. }
  184. packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
  185. packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
  186. request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
  187. request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
  188. packet.AppendChild(request)
  189. l.Debug.PrintPacket(packet)
  190. msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
  191. if err != nil {
  192. return err
  193. }
  194. defer l.finishMessage(msgCtx)
  195. l.Debug.Printf("%d: waiting for response", msgCtx.id)
  196. packetResponse, ok := <-msgCtx.responses
  197. if !ok {
  198. return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
  199. }
  200. packet, err = packetResponse.ReadPacket()
  201. l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
  202. if err != nil {
  203. return err
  204. }
  205. if l.Debug {
  206. if err := addLDAPDescriptions(packet); err != nil {
  207. l.Close()
  208. return err
  209. }
  210. ber.PrintPacket(packet)
  211. }
  212. if resultCode, message := getLDAPResultCode(packet); resultCode == LDAPResultSuccess {
  213. conn := tls.Client(l.conn, config)
  214. if err := conn.Handshake(); err != nil {
  215. l.Close()
  216. return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err))
  217. }
  218. l.isTLS = true
  219. l.conn = conn
  220. } else {
  221. return NewError(resultCode, fmt.Errorf("ldap: cannot StartTLS (%s)", message))
  222. }
  223. go l.reader()
  224. return nil
  225. }
  226. func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
  227. return l.sendMessageWithFlags(packet, 0)
  228. }
  229. func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
  230. if l.isClosing() {
  231. return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
  232. }
  233. l.messageMutex.Lock()
  234. l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
  235. if l.isStartingTLS {
  236. l.messageMutex.Unlock()
  237. return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
  238. }
  239. if flags&startTLS != 0 {
  240. if l.outstandingRequests != 0 {
  241. l.messageMutex.Unlock()
  242. return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
  243. }
  244. l.isStartingTLS = true
  245. }
  246. l.outstandingRequests++
  247. l.messageMutex.Unlock()
  248. responses := make(chan *PacketResponse)
  249. messageID := packet.Children[0].Value.(int64)
  250. message := &messagePacket{
  251. Op: MessageRequest,
  252. MessageID: messageID,
  253. Packet: packet,
  254. Context: &messageContext{
  255. id: messageID,
  256. done: make(chan struct{}),
  257. responses: responses,
  258. },
  259. }
  260. l.sendProcessMessage(message)
  261. return message.Context, nil
  262. }
  263. func (l *Conn) finishMessage(msgCtx *messageContext) {
  264. close(msgCtx.done)
  265. if l.isClosing() {
  266. return
  267. }
  268. l.messageMutex.Lock()
  269. l.outstandingRequests--
  270. if l.isStartingTLS {
  271. l.isStartingTLS = false
  272. }
  273. l.messageMutex.Unlock()
  274. message := &messagePacket{
  275. Op: MessageFinish,
  276. MessageID: msgCtx.id,
  277. }
  278. l.sendProcessMessage(message)
  279. }
  280. func (l *Conn) sendProcessMessage(message *messagePacket) bool {
  281. l.messageMutex.Lock()
  282. defer l.messageMutex.Unlock()
  283. if l.isClosing() {
  284. return false
  285. }
  286. l.chanMessage <- message
  287. return true
  288. }
// processMessages is the connection's central event loop, started by Start in
// its own goroutine. It hands out sequential message IDs on l.chanMessageID
// and serves the commands arriving on l.chanMessage:
//
//   - MessageQuit: exit the loop (sent by Close).
//   - MessageRequest: write the packet to the network and register its
//     context; if a request timeout is set, schedule a MessageTimeout.
//   - MessageResponse: deliver a packet read by reader to the matching context.
//   - MessageTimeout: fail the matching request with a timeout error and drop it.
//   - MessageFinish: drop the matching context (sent by finishMessage).
//
// After NewConn initializes l.messageContexts, this function is the only
// place that reads or writes it.
//
// On exit (including after a recovered panic), every remaining context's
// responses channel is closed. If the connection is closing with a recorded
// closeErr, that error is delivered to the context first. Then chanMessageID
// and chanConfirm are closed, which makes nextMessageID return 0 and unblocks
// Close.
func (l *Conn) processMessages() {
	defer func() {
		if err := recover(); err != nil {
			log.Printf("ldap: recovered panic in processMessages: %v", err)
		}
		for messageID, msgCtx := range l.messageContexts {
			// If we are closing due to an error, inform anyone who
			// is waiting about the error.
			if l.isClosing() && l.closeErr.Load() != nil {
				msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
			}
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
			close(msgCtx.responses)
			delete(l.messageContexts, messageID)
		}
		close(l.chanMessageID)
		close(l.chanConfirm)
	}()
	// Message IDs start at 1 and increase by one per ID handed out.
	var messageID int64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			messageID++
		case message := <-l.chanMessage:
			switch message.Op {
			case MessageQuit:
				l.Debug.Printf("Shutting down - quit message received")
				return
			case MessageRequest:
				// Add to message list and write to network
				l.Debug.Printf("Sending message %d", message.MessageID)
				buf := message.Packet.Bytes()
				_, err := l.conn.Write(buf)
				if err != nil {
					// The request is never registered: report the error and
					// close its channel here so the caller is not left waiting.
					l.Debug.Printf("Error Sending Message: %s", err.Error())
					message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
					close(message.Context.responses)
					// Leaves the switch only; the event loop keeps running.
					break
				}
				// Only add to messageContexts if we were able to
				// successfully write the message.
				l.messageContexts[message.MessageID] = message.Context
				// Add timeout if defined
				requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
				if requestTimeout > 0 {
					// NOTE(review): this goroutine always sleeps for the full
					// timeout, even if the request finishes earlier; a late
					// MessageTimeout for a dropped ID is simply ignored below,
					// and sendProcessMessage discards it once closing.
					go func() {
						defer func() {
							if err := recover(); err != nil {
								log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
							}
						}()
						time.Sleep(requestTimeout)
						timeoutMessage := &messagePacket{
							Op:        MessageTimeout,
							MessageID: message.MessageID,
						}
						l.sendProcessMessage(timeoutMessage)
					}()
				}
			case MessageResponse:
				l.Debug.Printf("Receiving message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
				} else {
					// No registered context (unknown, finished or timed-out
					// ID). Note: this logs and prints the packet even when
					// Debug is off.
					log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing())
					ber.PrintPacket(message.Packet)
				}
			case MessageTimeout:
				// Deliver a timeout error, then drop the context and close its
				// channel so further reads by the caller return immediately.
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
					msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			case MessageFinish:
				// The caller is done with this request; drop its context.
				l.Debug.Printf("Finished message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			}
		}
	}
}
  375. func (l *Conn) reader() {
  376. cleanstop := false
  377. defer func() {
  378. if err := recover(); err != nil {
  379. log.Printf("ldap: recovered panic in reader: %v", err)
  380. }
  381. if !cleanstop {
  382. l.Close()
  383. }
  384. }()
  385. for {
  386. if cleanstop {
  387. l.Debug.Printf("reader clean stopping (without closing the connection)")
  388. return
  389. }
  390. packet, err := ber.ReadPacket(l.conn)
  391. if err != nil {
  392. // A read error is expected here if we are closing the connection...
  393. if !l.isClosing() {
  394. l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
  395. l.Debug.Printf("reader error: %s", err.Error())
  396. }
  397. return
  398. }
  399. addLDAPDescriptions(packet)
  400. if len(packet.Children) == 0 {
  401. l.Debug.Printf("Received bad ldap packet")
  402. continue
  403. }
  404. l.messageMutex.Lock()
  405. if l.isStartingTLS {
  406. cleanstop = true
  407. }
  408. l.messageMutex.Unlock()
  409. message := &messagePacket{
  410. Op: MessageResponse,
  411. MessageID: packet.Children[0].Value.(int64),
  412. Packet: packet,
  413. }
  414. if !l.sendProcessMessage(message) {
  415. return
  416. }
  417. }
  418. }