tds.go 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367
  1. package mssql
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "io/ioutil"
  11. "net"
  12. "net/url"
  13. "os"
  14. "sort"
  15. "strconv"
  16. "strings"
  17. "time"
  18. "unicode"
  19. "unicode/utf16"
  20. "unicode/utf8"
  21. )
  // parseInstances decodes the instance-list reply read by getInstances.
  //
  // The reply must be longer than three bytes and start with 0x05; the first
  // three bytes are skipped. The remainder is a ';'-separated stream of
  // name;value pairs in which an empty name terminates the current instance
  // record. Each record is stored under the upper-cased value of its
  // "InstanceName" entry. Input that does not match yields an empty map.
  func parseInstances(msg []byte) map[string]map[string]string {
  	results := map[string]map[string]string{}
  	if len(msg) <= 3 || msg[0] != 5 {
  		return results
  	}
  	record := map[string]string{}
  	var key string
  	awaitingValue := false
  	for _, tok := range strings.Split(string(msg[3:]), ";") {
  		if awaitingValue {
  			record[key] = tok
  			awaitingValue = false
  			continue
  		}
  		key = tok
  		if key != "" {
  			awaitingValue = true
  			continue
  		}
  		// Empty name: end of one record, or of the whole list if nothing is pending.
  		if len(record) == 0 {
  			break
  		}
  		results[strings.ToUpper(record["InstanceName"])] = record
  		record = map[string]string{}
  	}
  	return results
  }
  50. func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
  51. maxTime := 5 * time.Second
  52. ctx, cancel := context.WithTimeout(ctx, maxTime)
  53. defer cancel()
  54. conn, err := d.DialContext(ctx, "udp", address+":1434")
  55. if err != nil {
  56. return nil, err
  57. }
  58. defer conn.Close()
  59. conn.SetDeadline(time.Now().Add(maxTime))
  60. _, err = conn.Write([]byte{3})
  61. if err != nil {
  62. return nil, err
  63. }
  64. var resp = make([]byte, 16*1024-1)
  65. read, err := conn.Read(resp)
  66. if err != nil {
  67. return nil, err
  68. }
  69. return parseInstances(resp[:read]), nil
  70. }
  // TDS protocol version identifiers, as 32-bit values (see MS-TDS).
  // verTDS73 is an alias for verTDS73A.
  const (
  	verTDS70     = 0x70000000
  	verTDS71     = 0x71000000
  	verTDS71rev1 = 0x71000001
  	verTDS72     = 0x72090002
  	verTDS73A    = 0x730A0003
  	verTDS73     = verTDS73A
  	verTDS73B    = 0x730B0003
  	verTDS74     = 0x74000004
  )
  82. // packet types
  83. // https://msdn.microsoft.com/en-us/library/dd304214.aspx
  84. const (
  85. packSQLBatch packetType = 1
  86. packRPCRequest = 3
  87. packReply = 4
  88. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  89. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  90. packAttention = 6
  91. packBulkLoadBCP = 7
  92. packTransMgrReq = 14
  93. packNormal = 15
  94. packLogin7 = 16
  95. packSSPIMessage = 17
  96. packPrelogin = 18
  97. )
  // PRELOGIN option tokens; writePrelogin/readPrelogin key option data by these.
  // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  const (
  	preloginVERSION    = 0
  	preloginENCRYPTION = 1
  	preloginINSTOPT    = 2
  	preloginTHREADID   = 3
  	preloginMARS       = 4
  	preloginTRACEID    = 5
  	preloginTERMINATOR = 0xff // ends the option table
  )

  // Values of the PRELOGIN ENCRYPTION option.
  const (
  	encryptOff    = 0 // Encryption is available but off.
  	encryptOn     = 1 // Encryption is available and on.
  	encryptNotSup = 2 // Encryption is not available.
  	encryptReq    = 3 // Encryption is required.
  )
  115. type tdsSession struct {
  116. buf *tdsBuffer
  117. loginAck loginAckStruct
  118. database string
  119. partner string
  120. columns []columnStruct
  121. tranid uint64
  122. logFlags uint64
  123. log optionalLogger
  124. routedServer string
  125. routedPort uint16
  126. }
  // Bit flags for tdsSession.logFlags / connectParams.logFlags; combine with |.
  const (
  	logErrors      = 1 << iota // 1
  	logMessages                // 2
  	logRows                    // 4
  	logSQL                     // 8
  	logParams                  // 16
  	logTransaction             // 32
  	logDebug                   // 64
  )
  136. type columnStruct struct {
  137. UserType uint32
  138. Flags uint16
  139. ColName string
  140. ti typeInfo
  141. }
  // keySlice implements sort.Interface over byte-sized keys in ascending order.
  type keySlice []uint8

  func (ks keySlice) Len() int           { return len(ks) }
  func (ks keySlice) Less(a, b int) bool { return ks[a] < ks[b] }
  func (ks keySlice) Swap(a, b int)      { ks[a], ks[b] = ks[b], ks[a] }
  146. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  147. func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
  148. var err error
  149. w.BeginPacket(packPrelogin, false)
  150. offset := uint16(5*len(fields) + 1)
  151. keys := make(keySlice, 0, len(fields))
  152. for k, _ := range fields {
  153. keys = append(keys, k)
  154. }
  155. sort.Sort(keys)
  156. // writing header
  157. for _, k := range keys {
  158. err = w.WriteByte(k)
  159. if err != nil {
  160. return err
  161. }
  162. err = binary.Write(w, binary.BigEndian, offset)
  163. if err != nil {
  164. return err
  165. }
  166. v := fields[k]
  167. size := uint16(len(v))
  168. err = binary.Write(w, binary.BigEndian, size)
  169. if err != nil {
  170. return err
  171. }
  172. offset += size
  173. }
  174. err = w.WriteByte(preloginTERMINATOR)
  175. if err != nil {
  176. return err
  177. }
  178. // writing values
  179. for _, k := range keys {
  180. v := fields[k]
  181. written, err := w.Write(v)
  182. if err != nil {
  183. return err
  184. }
  185. if written != len(v) {
  186. return errors.New("Write method didn't write the whole value")
  187. }
  188. }
  189. return w.FinishPacket()
  190. }
  191. func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
  192. packet_type, err := r.BeginRead()
  193. if err != nil {
  194. return nil, err
  195. }
  196. struct_buf, err := ioutil.ReadAll(r)
  197. if err != nil {
  198. return nil, err
  199. }
  200. if packet_type != 4 {
  201. return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
  202. }
  203. offset := 0
  204. results := map[uint8][]byte{}
  205. for true {
  206. rec_type := struct_buf[offset]
  207. if rec_type == preloginTERMINATOR {
  208. break
  209. }
  210. rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
  211. rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
  212. value := struct_buf[rec_offset : rec_offset+rec_len]
  213. results[rec_type] = value
  214. offset += 5
  215. }
  216. return results, nil
  217. }
  // Bits for login.OptionFlags2.
  // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  const (
  	fLanguageFatal = 1
  	fODBC          = 2
  	fTransBoundary = 4
  	fCacheConnect  = 8
  	fIntSecurity   = 0x80
  )

  // Bits for login.TypeFlags. The low 4 bits are fSQLType and the next bit is
  // fOLEDB; only the read-only intent bit is named here.
  const (
  	fReadOnlyIntent = 32
  )
  // login holds the client-side values of a LOGIN7 request. sendLogin copies
  // the numeric fields into a loginHeader and appends the strings as UTF-16LE
  // (the password mangled by manglePassword) plus the raw SSPI bytes.
  type login struct {
  	TDSVersion     uint32 // one of the verTDS* values, presumably
  	PacketSize     uint32
  	ClientProgVer  uint32
  	ClientPID      uint32
  	ConnectionID   uint32
  	OptionFlags1   uint8
  	OptionFlags2   uint8 // f* OptionFlags2 bits
  	TypeFlags      uint8 // e.g. fReadOnlyIntent
  	OptionFlags3   uint8
  	ClientTimeZone int32
  	ClientLCID     uint32
  	HostName       string
  	UserName       string
  	Password       string
  	AppName        string
  	ServerName     string
  	CtlIntName     string
  	Language       string
  	Database       string
  	ClientID       [6]byte
  	SSPI           []byte
  	AtchDBFile     string
  	ChangePassword string
  }
  // loginHeader is the fixed-size part of a LOGIN7 packet. sendLogin serializes
  // it with binary.Write (little-endian), so field order, count and types define
  // the wire layout (binary.Size reports 94 bytes). Do not reorder, add or
  // retype fields. ExtensionLenght keeps its historical spelling.
  //
  // Each *Offset is measured from the start of this header; each *Length
  // counts characters for strings and bytes for SSPI.
  type loginHeader struct {
  	Length               uint32 // total size of header plus variable data
  	TDSVersion           uint32
  	PacketSize           uint32
  	ClientProgVer        uint32
  	ClientPID            uint32
  	ConnectionID         uint32
  	OptionFlags1         uint8
  	OptionFlags2         uint8
  	TypeFlags            uint8
  	OptionFlags3         uint8
  	ClientTimeZone       int32
  	ClientLCID           uint32
  	HostNameOffset       uint16
  	HostNameLength       uint16
  	UserNameOffset       uint16
  	UserNameLength       uint16
  	PasswordOffset       uint16
  	PasswordLength       uint16
  	AppNameOffset        uint16
  	AppNameLength        uint16
  	ServerNameOffset     uint16
  	ServerNameLength     uint16
  	ExtensionOffset      uint16
  	ExtensionLenght      uint16
  	CtlIntNameOffset     uint16
  	CtlIntNameLength     uint16
  	LanguageOffset       uint16
  	LanguageLength       uint16
  	DatabaseOffset       uint16
  	DatabaseLength       uint16
  	ClientID             [6]byte
  	SSPIOffset           uint16
  	SSPILength           uint16
  	AtchDBFileOffset     uint16
  	AtchDBFileLength     uint16
  	ChangePasswordOffset uint16
  	ChangePasswordLength uint16
  	SSPILongLength       uint32
  }
  // str2ucs2 encodes s as UTF-16 little-endian bytes. Characters outside the
  // BMP become surrogate pairs (4 bytes). The bytes are laid out by hand,
  // low byte first, instead of going through the bytes/binary packages.
  func str2ucs2(s string) []byte {
  	units := utf16.Encode([]rune(s))
  	out := make([]byte, 0, 2*len(units))
  	for _, u := range units {
  		out = append(out, byte(u), byte(u>>8))
  	}
  	return out
  }

  // ucs22str decodes UTF-16 little-endian bytes into a Go string. It fails if
  // the input has an odd number of bytes.
  func ucs22str(s []byte) (string, error) {
  	if len(s)%2 == 1 {
  		return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
  	}
  	units := make([]uint16, 0, len(s)/2)
  	for rest := s; len(rest) > 0; rest = rest[2:] {
  		units = append(units, binary.LittleEndian.Uint16(rest))
  	}
  	return string(utf16.Decode(units)), nil
  }

  // manglePassword returns the LOGIN7 password encoding of password. The
  // password is encoded with str2ucs2, then each byte has its nibbles swapped
  // and is XORed with 0xA5.
  func manglePassword(password string) []byte {
  	buf := str2ucs2(password)
  	for i := range buf {
  		b := buf[i]
  		buf[i] = (b<<4 | b>>4) ^ 0xA5
  	}
  	return buf
  }
  327. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  328. func sendLogin(w *tdsBuffer, login login) error {
  329. w.BeginPacket(packLogin7, false)
  330. hostname := str2ucs2(login.HostName)
  331. username := str2ucs2(login.UserName)
  332. password := manglePassword(login.Password)
  333. appname := str2ucs2(login.AppName)
  334. servername := str2ucs2(login.ServerName)
  335. ctlintname := str2ucs2(login.CtlIntName)
  336. language := str2ucs2(login.Language)
  337. database := str2ucs2(login.Database)
  338. atchdbfile := str2ucs2(login.AtchDBFile)
  339. changepassword := str2ucs2(login.ChangePassword)
  340. hdr := loginHeader{
  341. TDSVersion: login.TDSVersion,
  342. PacketSize: login.PacketSize,
  343. ClientProgVer: login.ClientProgVer,
  344. ClientPID: login.ClientPID,
  345. ConnectionID: login.ConnectionID,
  346. OptionFlags1: login.OptionFlags1,
  347. OptionFlags2: login.OptionFlags2,
  348. TypeFlags: login.TypeFlags,
  349. OptionFlags3: login.OptionFlags3,
  350. ClientTimeZone: login.ClientTimeZone,
  351. ClientLCID: login.ClientLCID,
  352. HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
  353. UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
  354. PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
  355. AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
  356. ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
  357. CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
  358. LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
  359. DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
  360. ClientID: login.ClientID,
  361. SSPILength: uint16(len(login.SSPI)),
  362. AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
  363. ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
  364. }
  365. offset := uint16(binary.Size(hdr))
  366. hdr.HostNameOffset = offset
  367. offset += uint16(len(hostname))
  368. hdr.UserNameOffset = offset
  369. offset += uint16(len(username))
  370. hdr.PasswordOffset = offset
  371. offset += uint16(len(password))
  372. hdr.AppNameOffset = offset
  373. offset += uint16(len(appname))
  374. hdr.ServerNameOffset = offset
  375. offset += uint16(len(servername))
  376. hdr.CtlIntNameOffset = offset
  377. offset += uint16(len(ctlintname))
  378. hdr.LanguageOffset = offset
  379. offset += uint16(len(language))
  380. hdr.DatabaseOffset = offset
  381. offset += uint16(len(database))
  382. hdr.SSPIOffset = offset
  383. offset += uint16(len(login.SSPI))
  384. hdr.AtchDBFileOffset = offset
  385. offset += uint16(len(atchdbfile))
  386. hdr.ChangePasswordOffset = offset
  387. offset += uint16(len(changepassword))
  388. hdr.Length = uint32(offset)
  389. var err error
  390. err = binary.Write(w, binary.LittleEndian, &hdr)
  391. if err != nil {
  392. return err
  393. }
  394. _, err = w.Write(hostname)
  395. if err != nil {
  396. return err
  397. }
  398. _, err = w.Write(username)
  399. if err != nil {
  400. return err
  401. }
  402. _, err = w.Write(password)
  403. if err != nil {
  404. return err
  405. }
  406. _, err = w.Write(appname)
  407. if err != nil {
  408. return err
  409. }
  410. _, err = w.Write(servername)
  411. if err != nil {
  412. return err
  413. }
  414. _, err = w.Write(ctlintname)
  415. if err != nil {
  416. return err
  417. }
  418. _, err = w.Write(language)
  419. if err != nil {
  420. return err
  421. }
  422. _, err = w.Write(database)
  423. if err != nil {
  424. return err
  425. }
  426. _, err = w.Write(login.SSPI)
  427. if err != nil {
  428. return err
  429. }
  430. _, err = w.Write(atchdbfile)
  431. if err != nil {
  432. return err
  433. }
  434. _, err = w.Write(changepassword)
  435. if err != nil {
  436. return err
  437. }
  438. return w.FinishPacket()
  439. }
  440. func readUcs2(r io.Reader, numchars int) (res string, err error) {
  441. buf := make([]byte, numchars*2)
  442. _, err = io.ReadFull(r, buf)
  443. if err != nil {
  444. return "", err
  445. }
  446. return ucs22str(buf)
  447. }
  448. func readUsVarChar(r io.Reader) (res string, err error) {
  449. var numchars uint16
  450. err = binary.Read(r, binary.LittleEndian, &numchars)
  451. if err != nil {
  452. return "", err
  453. }
  454. return readUcs2(r, int(numchars))
  455. }
  456. func writeUsVarChar(w io.Writer, s string) (err error) {
  457. buf := str2ucs2(s)
  458. var numchars int = len(buf) / 2
  459. if numchars > 0xffff {
  460. panic("invalid size for US_VARCHAR")
  461. }
  462. err = binary.Write(w, binary.LittleEndian, uint16(numchars))
  463. if err != nil {
  464. return
  465. }
  466. _, err = w.Write(buf)
  467. return
  468. }
  469. func readBVarChar(r io.Reader) (res string, err error) {
  470. var numchars uint8
  471. err = binary.Read(r, binary.LittleEndian, &numchars)
  472. if err != nil {
  473. return "", err
  474. }
  475. // A zero length could be returned, return an empty string
  476. if numchars == 0 {
  477. return "", nil
  478. }
  479. return readUcs2(r, int(numchars))
  480. }
  481. func writeBVarChar(w io.Writer, s string) (err error) {
  482. buf := str2ucs2(s)
  483. var numchars int = len(buf) / 2
  484. if numchars > 0xff {
  485. panic("invalid size for B_VARCHAR")
  486. }
  487. err = binary.Write(w, binary.LittleEndian, uint8(numchars))
  488. if err != nil {
  489. return
  490. }
  491. _, err = w.Write(buf)
  492. return
  493. }
  // readBVarByte reads a one-byte length followed by that many raw bytes.
  func readBVarByte(r io.Reader) (res []byte, err error) {
  	var size uint8
  	if err = binary.Read(r, binary.LittleEndian, &size); err != nil {
  		return nil, err
  	}
  	payload := make([]byte, size)
  	_, err = io.ReadFull(r, payload)
  	return payload, err
  }
  // readUshort reads a little-endian uint16 from r.
  func readUshort(r io.Reader) (res uint16, err error) {
  	var v uint16
  	err = binary.Read(r, binary.LittleEndian, &v)
  	return v, err
  }
  // readByte reads exactly one byte from r.
  //
  // io.ReadFull is used instead of a single r.Read. An io.Reader may legally
  // return (0, nil), which used to yield a spurious zero byte. It may also
  // return the byte together with an error such as io.EOF, which used to be
  // reported as a failure even though the byte arrived.
  func readByte(r io.Reader) (res byte, err error) {
  	var b [1]byte
  	if _, err = io.ReadFull(r, b[:]); err != nil {
  		return 0, err
  	}
  	return b[0], nil
  }
  // headerStruct is one entry of the ALL_HEADERS prefix written by
  // writeAllHeaders: a header type and its already-packed payload.
  // http://msdn.microsoft.com/en-us/library/dd304953.aspx
  type headerStruct struct {
  	hdrtype uint16 // one of the dataStmHdr* values
  	data    []byte
  }

  // Header type values for headerStruct.hdrtype.
  const (
  	dataStmHdrQueryNotif    = 1 // query notifications
  	dataStmHdrTransDescr    = 2 // MARS transaction descriptor (required)
  	dataStmHdrTraceActivity = 3
  )
  525. // Query Notifications Header
  526. // http://msdn.microsoft.com/en-us/library/dd304949.aspx
  527. type queryNotifHdr struct {
  528. notifyId string
  529. ssbDeployment string
  530. notifyTimeout uint32
  531. }
  532. func (hdr queryNotifHdr) pack() (res []byte) {
  533. notifyId := str2ucs2(hdr.notifyId)
  534. ssbDeployment := str2ucs2(hdr.ssbDeployment)
  535. res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
  536. b := res
  537. binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
  538. b = b[2:]
  539. copy(b, notifyId)
  540. b = b[len(notifyId):]
  541. binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
  542. b = b[2:]
  543. copy(b, ssbDeployment)
  544. b = b[len(ssbDeployment):]
  545. binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
  546. return res
  547. }
  // transDescrHdr is the payload of a MARS transaction descriptor header.
  // http://msdn.microsoft.com/en-us/library/dd340515.aspx
  type transDescrHdr struct {
  	transDescr        uint64 // transaction descriptor returned from ENVCHANGE
  	outstandingReqCnt uint32 // outstanding request count
  }

  // pack serializes hdr as 12 little-endian bytes: the uint64 descriptor
  // followed by the uint32 request count.
  func (hdr transDescrHdr) pack() (res []byte) {
  	var buf [12]byte
  	binary.LittleEndian.PutUint64(buf[:8], hdr.transDescr)
  	binary.LittleEndian.PutUint32(buf[8:], hdr.outstandingReqCnt)
  	return buf[:]
  }
  560. func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
  561. // Calculating total length.
  562. var totallen uint32 = 4
  563. for _, hdr := range headers {
  564. totallen += 4 + 2 + uint32(len(hdr.data))
  565. }
  566. // writing
  567. err = binary.Write(w, binary.LittleEndian, totallen)
  568. if err != nil {
  569. return err
  570. }
  571. for _, hdr := range headers {
  572. var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
  573. err = binary.Write(w, binary.LittleEndian, headerlen)
  574. if err != nil {
  575. return err
  576. }
  577. err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
  578. if err != nil {
  579. return err
  580. }
  581. _, err = w.Write(hdr.data)
  582. if err != nil {
  583. return err
  584. }
  585. }
  586. return nil
  587. }
  588. func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
  589. buf.BeginPacket(packSQLBatch, resetSession)
  590. if err = writeAllHeaders(buf, headers); err != nil {
  591. return
  592. }
  593. _, err = buf.Write(str2ucs2(sqltext))
  594. if err != nil {
  595. return
  596. }
  597. return buf.FinishPacket()
  598. }
  599. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  600. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  601. func sendAttention(buf *tdsBuffer) error {
  602. buf.BeginPacket(packAttention, false)
  603. return buf.FinishPacket()
  604. }
  // connectParams holds the settings parsed from a connection string by
  // parseConnectParams.
  type connectParams struct {
  	logFlags               uint64        // "log" parameter; bitmask of log* constants
  	port                   uint64        // "port"; defaults to 1433
  	host                   string        // server part before any '\'; "." and "(local)" become "localhost"
  	instance               string        // server part after '\'
  	database               string        // "database"
  	user                   string        // "user id"
  	password               string        // "password"
  	dial_timeout           time.Duration // "dial timeout" in seconds; defaults to 15s
  	conn_timeout           time.Duration // "connection timeout" in seconds; 0 unless set
  	keepAlive              time.Duration // "keepalive" in seconds; defaults to 30s
  	encrypt                bool          // "encrypt" parsed as a bool
  	disableEncryption      bool          // set when "encrypt" is "DISABLE"
  	trustServerCertificate bool
  	certificate            string
  	hostInCertificate      string
  	serverSPN              string
  	workstation            string
  	appname                string
  	typeFlags              uint8
  	failOverPartner        string
  	failOverPort           uint64
  	packetSize             uint16 // "packet size"; defaults to 4096, clamped to 512..32767
  }
  // splitConnectionString parses an ADO-style "key=value;key=value" string.
  // Keys are lower-cased and trimmed, values are trimmed, and only the first
  // '=' separates key from value. Empty segments and segments with an empty
  // key are skipped. A segment without '=' maps its key to "".
  func splitConnectionString(dsn string) (res map[string]string) {
  	res = make(map[string]string)
  	for _, pair := range strings.Split(dsn, ";") {
  		if pair == "" {
  			continue
  		}
  		kv := strings.SplitN(pair, "=", 2)
  		key := strings.TrimSpace(strings.ToLower(kv[0]))
  		if key == "" {
  			continue
  		}
  		val := ""
  		if len(kv) == 2 {
  			val = strings.TrimSpace(kv[1])
  		}
  		res[key] = val
  	}
  	return res
  }
  // splitConnectionStringOdbc parses an ODBC-style connection string into a map
  // from normalized key (see normalizeOdbcKey) to value.
  //
  // Pairs are separated by ';', and whitespace before a key or value is
  // skipped. A bare value runs to the next ';' and loses trailing whitespace.
  // A value wrapped in braces is taken verbatim, may contain ';', and spells a
  // literal '}' as "}}". A key with no '=' maps to "". On a syntax error, the
  // pairs parsed so far are returned along with the error.
  func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
  	type odbcState int
  	const (
  		beforeKey          odbcState = iota // skipping whitespace and ';' before a key
  		inKey                               // collecting key characters
  		valueStart                          // after '=', before the value's first character
  		inBareValue                         // collecting an unbraced value
  		inBracedValue                       // collecting a {braced} value
  		braceInBracedValue                  // just saw '}' inside a braced value
  		afterBracedValue                    // after the closing '}'; only ';' or whitespace may follow
  	)

  	res := map[string]string{}
  	state := beforeKey
  	var key string
  	var value strings.Builder

  	for i, c := range dsn {
  		switch state {
  		case beforeKey:
  			if c == '=' {
  				return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
  			}
  			if c != ';' && !unicode.IsSpace(c) {
  				key += string(c)
  				state = inKey
  			}

  		case inKey:
  			if c != '=' && c != ';' {
  				key += string(c)
  				continue
  			}
  			key = normalizeOdbcKey(key)
  			if key == "" {
  				return res, fmt.Errorf("Unexpected end of key at index %d.", i)
  			}
  			if c == '=' {
  				state = valueStart
  				continue
  			}
  			// ';' straight after a key: the key has no value.
  			res[key] = value.String()
  			key = ""
  			value.Reset()
  			state = beforeKey

  		case valueStart:
  			switch {
  			case c == '{':
  				state = inBracedValue
  			case c == ';':
  				res[key] = value.String()
  				key = ""
  				state = beforeKey
  			case unicode.IsSpace(c):
  				// leading whitespace is not part of the value
  			default:
  				value.WriteRune(c)
  				state = inBareValue
  			}

  		case inBareValue:
  			if c != ';' {
  				value.WriteRune(c)
  				continue
  			}
  			res[key] = strings.TrimRightFunc(value.String(), unicode.IsSpace)
  			key = ""
  			value.Reset()
  			state = beforeKey

  		case inBracedValue:
  			if c == '}' {
  				state = braceInBracedValue
  			} else {
  				value.WriteRune(c)
  			}

  		case braceInBracedValue:
  			if c == '}' {
  				// "}}" is an escaped literal brace; stay inside the value.
  				value.WriteRune(c)
  				state = inBracedValue
  				continue
  			}
  			// The previous '}' closed the value; c is the first character after it.
  			res[key] = value.String()
  			key = ""
  			value.Reset()
  			switch {
  			case c == ';':
  				state = beforeKey
  			case unicode.IsSpace(c):
  				state = afterBracedValue
  			default:
  				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  			}

  		case afterBracedValue:
  			switch {
  			case c == ';':
  				state = beforeKey
  			case unicode.IsSpace(c):
  				// still between the closing brace and the separator
  			default:
  				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  			}
  		}
  	}

  	// Flush whatever pair was in progress when the input ended.
  	switch state {
  	case inKey:
  		key = normalizeOdbcKey(key)
  		if key == "" {
  			return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
  		}
  		res[key] = value.String()
  	case valueStart, braceInBracedValue:
  		res[key] = value.String()
  	case inBareValue:
  		res[key] = strings.TrimRightFunc(value.String(), unicode.IsSpace)
  	case inBracedValue:
  		return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
  	}
  	return res, nil
  }

  // normalizeOdbcKey canonicalizes an ODBC key by removing trailing whitespace
  // and lower-casing the result. Leading whitespace is left untouched.
  func normalizeOdbcKey(s string) string {
  	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
  	return strings.ToLower(trimmed)
  }
  // splitConnectionStringURL parses a URL of the form
  // sqlserver://username:password@host:port/instance?param1=value&param2=value
  // into the same key/value map the other connection string parsers produce:
  // "user id", "password", "server" (host, plus `\instance` when a path is
  // present), "port", and every query parameter under its lower-cased name.
  // A query parameter given more than once is an error.
  func splitConnectionStringURL(dsn string) (map[string]string, error) {
  	res := map[string]string{}
  	parsed, err := url.Parse(dsn)
  	if err != nil {
  		return res, err
  	}
  	if parsed.Scheme != "sqlserver" {
  		return res, fmt.Errorf("scheme %s is not recognized", parsed.Scheme)
  	}

  	if info := parsed.User; info != nil {
  		res["user id"] = info.Username()
  		if pw, ok := info.Password(); ok {
  			res["password"] = pw
  		}
  	}

  	// Without a port SplitHostPort fails (and returns empty strings); fall back
  	// to the raw host in that case.
  	host, port, splitErr := net.SplitHostPort(parsed.Host)
  	if splitErr != nil {
  		host = parsed.Host
  	}
  	server := host
  	if parsed.Path != "" {
  		server += `\` + parsed.Path[1:]
  	}
  	res["server"] = server
  	if port != "" {
  		res["port"] = port
  	}

  	for name, values := range parsed.Query() {
  		if len(values) > 1 {
  			return res, fmt.Errorf("key %s provided more than once", name)
  		}
  		res[strings.ToLower(name)] = values[0]
  	}
  	return res, nil
  }
  829. func parseConnectParams(dsn string) (connectParams, error) {
  830. var p connectParams
  831. var params map[string]string
  832. if strings.HasPrefix(dsn, "odbc:") {
  833. parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
  834. if err != nil {
  835. return p, err
  836. }
  837. params = parameters
  838. } else if strings.HasPrefix(dsn, "sqlserver://") {
  839. parameters, err := splitConnectionStringURL(dsn)
  840. if err != nil {
  841. return p, err
  842. }
  843. params = parameters
  844. } else {
  845. params = splitConnectionString(dsn)
  846. }
  847. strlog, ok := params["log"]
  848. if ok {
  849. var err error
  850. p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
  851. if err != nil {
  852. return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
  853. }
  854. }
  855. server := params["server"]
  856. parts := strings.SplitN(server, `\`, 2)
  857. p.host = parts[0]
  858. if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
  859. p.host = "localhost"
  860. }
  861. if len(parts) > 1 {
  862. p.instance = parts[1]
  863. }
  864. p.database = params["database"]
  865. p.user = params["user id"]
  866. p.password = params["password"]
  867. p.port = 1433
  868. strport, ok := params["port"]
  869. if ok {
  870. var err error
  871. p.port, err = strconv.ParseUint(strport, 10, 16)
  872. if err != nil {
  873. f := "Invalid tcp port '%v': %v"
  874. return p, fmt.Errorf(f, strport, err.Error())
  875. }
  876. }
  877. // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
  878. // Default packet size remains at 4096 bytes
  879. p.packetSize = 4096
  880. strpsize, ok := params["packet size"]
  881. if ok {
  882. var err error
  883. psize, err := strconv.ParseUint(strpsize, 0, 16)
  884. if err != nil {
  885. f := "Invalid packet size '%v': %v"
  886. return p, fmt.Errorf(f, strpsize, err.Error())
  887. }
  888. // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
  889. // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
  890. // a higher packet size, the server will respond with an ENVCHANGE request to
  891. // alter the packet size to 16383 bytes.
  892. p.packetSize = uint16(psize)
  893. if p.packetSize < 512 {
  894. p.packetSize = 512
  895. } else if p.packetSize > 32767 {
  896. p.packetSize = 32767
  897. }
  898. }
  899. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  900. //
  901. // Do not set a connection timeout. Use Context to manage such things.
  902. // Default to zero, but still allow it to be set.
  903. if strconntimeout, ok := params["connection timeout"]; ok {
  904. timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
  905. if err != nil {
  906. f := "Invalid connection timeout '%v': %v"
  907. return p, fmt.Errorf(f, strconntimeout, err.Error())
  908. }
  909. p.conn_timeout = time.Duration(timeout) * time.Second
  910. }
  911. p.dial_timeout = 15 * time.Second
  912. if strdialtimeout, ok := params["dial timeout"]; ok {
  913. timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
  914. if err != nil {
  915. f := "Invalid dial timeout '%v': %v"
  916. return p, fmt.Errorf(f, strdialtimeout, err.Error())
  917. }
  918. p.dial_timeout = time.Duration(timeout) * time.Second
  919. }
  920. // default keep alive should be 30 seconds according to spec:
  921. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  922. p.keepAlive = 30 * time.Second
  923. if keepAlive, ok := params["keepalive"]; ok {
  924. timeout, err := strconv.ParseUint(keepAlive, 10, 64)
  925. if err != nil {
  926. f := "Invalid keepAlive value '%s': %s"
  927. return p, fmt.Errorf(f, keepAlive, err.Error())
  928. }
  929. p.keepAlive = time.Duration(timeout) * time.Second
  930. }
  931. encrypt, ok := params["encrypt"]
  932. if ok {
  933. if strings.EqualFold(encrypt, "DISABLE") {
  934. p.disableEncryption = true
  935. } else {
  936. var err error
  937. p.encrypt, err = strconv.ParseBool(encrypt)
  938. if err != nil {
  939. f := "Invalid encrypt '%s': %s"
  940. return p, fmt.Errorf(f, encrypt, err.Error())
  941. }
  942. }
  943. } else {
  944. p.trustServerCertificate = true
  945. }
  946. trust, ok := params["trustservercertificate"]
  947. if ok {
  948. var err error
  949. p.trustServerCertificate, err = strconv.ParseBool(trust)
  950. if err != nil {
  951. f := "Invalid trust server certificate '%s': %s"
  952. return p, fmt.Errorf(f, trust, err.Error())
  953. }
  954. }
  955. p.certificate = params["certificate"]
  956. p.hostInCertificate, ok = params["hostnameincertificate"]
  957. if !ok {
  958. p.hostInCertificate = p.host
  959. }
  960. serverSPN, ok := params["serverspn"]
  961. if ok {
  962. p.serverSPN = serverSPN
  963. } else {
  964. p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
  965. }
  966. workstation, ok := params["workstation id"]
  967. if ok {
  968. p.workstation = workstation
  969. } else {
  970. workstation, err := os.Hostname()
  971. if err == nil {
  972. p.workstation = workstation
  973. }
  974. }
  975. appname, ok := params["app name"]
  976. if !ok {
  977. appname = "go-mssqldb"
  978. }
  979. p.appname = appname
  980. appintent, ok := params["applicationintent"]
  981. if ok {
  982. if appintent == "ReadOnly" {
  983. p.typeFlags |= fReadOnlyIntent
  984. }
  985. }
  986. failOverPartner, ok := params["failoverpartner"]
  987. if ok {
  988. p.failOverPartner = failOverPartner
  989. }
  990. failOverPort, ok := params["failoverport"]
  991. if ok {
  992. var err error
  993. p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
  994. if err != nil {
  995. f := "Invalid tcp port '%v': %v"
  996. return p, fmt.Errorf(f, failOverPort, err.Error())
  997. }
  998. }
  999. return p, nil
  1000. }
  1001. type auth interface {
  1002. InitialBytes() ([]byte, error)
  1003. NextBytes([]byte) ([]byte, error)
  1004. Free()
  1005. }
  1006. // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
  1007. // list of IP addresses. So if there is more than one, try them all and
  1008. // use the first one that allows a connection.
  1009. func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
  1010. var ips []net.IP
  1011. ips, err = net.LookupIP(p.host)
  1012. if err != nil {
  1013. ip := net.ParseIP(p.host)
  1014. if ip == nil {
  1015. return nil, err
  1016. }
  1017. ips = []net.IP{ip}
  1018. }
  1019. if len(ips) == 1 {
  1020. d := c.getDialer(&p)
  1021. addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
  1022. conn, err = d.DialContext(ctx, "tcp", addr)
  1023. } else {
  1024. //Try Dials in parallel to avoid waiting for timeouts.
  1025. connChan := make(chan net.Conn, len(ips))
  1026. errChan := make(chan error, len(ips))
  1027. portStr := strconv.Itoa(int(p.port))
  1028. for _, ip := range ips {
  1029. go func(ip net.IP) {
  1030. d := c.getDialer(&p)
  1031. addr := net.JoinHostPort(ip.String(), portStr)
  1032. conn, err := d.DialContext(ctx, "tcp", addr)
  1033. if err == nil {
  1034. connChan <- conn
  1035. } else {
  1036. errChan <- err
  1037. }
  1038. }(ip)
  1039. }
  1040. // Wait for either the *first* successful connection, or all the errors
  1041. wait_loop:
  1042. for i, _ := range ips {
  1043. select {
  1044. case conn = <-connChan:
  1045. // Got a connection to use, close any others
  1046. go func(n int) {
  1047. for i := 0; i < n; i++ {
  1048. select {
  1049. case conn := <-connChan:
  1050. conn.Close()
  1051. case <-errChan:
  1052. }
  1053. }
  1054. }(len(ips) - i - 1)
  1055. // Remove any earlier errors we may have collected
  1056. err = nil
  1057. break wait_loop
  1058. case err = <-errChan:
  1059. }
  1060. }
  1061. }
  1062. // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
  1063. if conn == nil {
  1064. f := "Unable to open tcp connection with host '%v:%v': %v"
  1065. return nil, fmt.Errorf(f, p.host, p.port, err.Error())
  1066. }
  1067. return conn, err
  1068. }
  1069. func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
  1070. dialCtx := ctx
  1071. if p.dial_timeout > 0 {
  1072. var cancel func()
  1073. dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
  1074. defer cancel()
  1075. }
  1076. // if instance is specified use instance resolution service
  1077. if p.instance != "" {
  1078. p.instance = strings.ToUpper(p.instance)
  1079. d := c.getDialer(&p)
  1080. instances, err := getInstances(dialCtx, d, p.host)
  1081. if err != nil {
  1082. f := "Unable to get instances from Sql Server Browser on host %v: %v"
  1083. return nil, fmt.Errorf(f, p.host, err.Error())
  1084. }
  1085. strport, ok := instances[p.instance]["tcp"]
  1086. if !ok {
  1087. f := "No instance matching '%v' returned from host '%v'"
  1088. return nil, fmt.Errorf(f, p.instance, p.host)
  1089. }
  1090. p.port, err = strconv.ParseUint(strport, 0, 16)
  1091. if err != nil {
  1092. f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
  1093. return nil, fmt.Errorf(f, strport, err.Error())
  1094. }
  1095. }
  1096. initiate_connection:
  1097. conn, err := dialConnection(dialCtx, c, p)
  1098. if err != nil {
  1099. return nil, err
  1100. }
  1101. toconn := newTimeoutConn(conn, p.conn_timeout)
  1102. outbuf := newTdsBuffer(p.packetSize, toconn)
  1103. sess := tdsSession{
  1104. buf: outbuf,
  1105. log: log,
  1106. logFlags: p.logFlags,
  1107. }
  1108. instance_buf := []byte(p.instance)
  1109. instance_buf = append(instance_buf, 0) // zero terminate instance name
  1110. var encrypt byte
  1111. if p.disableEncryption {
  1112. encrypt = encryptNotSup
  1113. } else if p.encrypt {
  1114. encrypt = encryptOn
  1115. } else {
  1116. encrypt = encryptOff
  1117. }
  1118. fields := map[uint8][]byte{
  1119. preloginVERSION: {0, 0, 0, 0, 0, 0},
  1120. preloginENCRYPTION: {encrypt},
  1121. preloginINSTOPT: instance_buf,
  1122. preloginTHREADID: {0, 0, 0, 0},
  1123. preloginMARS: {0}, // MARS disabled
  1124. }
  1125. err = writePrelogin(outbuf, fields)
  1126. if err != nil {
  1127. return nil, err
  1128. }
  1129. fields, err = readPrelogin(outbuf)
  1130. if err != nil {
  1131. return nil, err
  1132. }
  1133. encryptBytes, ok := fields[preloginENCRYPTION]
  1134. if !ok {
  1135. return nil, fmt.Errorf("Encrypt negotiation failed")
  1136. }
  1137. encrypt = encryptBytes[0]
  1138. if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
  1139. return nil, fmt.Errorf("Server does not support encryption")
  1140. }
  1141. if encrypt != encryptNotSup {
  1142. var config tls.Config
  1143. if p.certificate != "" {
  1144. pem, err := ioutil.ReadFile(p.certificate)
  1145. if err != nil {
  1146. return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
  1147. }
  1148. certs := x509.NewCertPool()
  1149. certs.AppendCertsFromPEM(pem)
  1150. config.RootCAs = certs
  1151. }
  1152. if p.trustServerCertificate {
  1153. config.InsecureSkipVerify = true
  1154. }
  1155. config.ServerName = p.hostInCertificate
  1156. // fix for https://github.com/denisenkom/go-mssqldb/issues/166
  1157. // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
  1158. // while SQL Server seems to expect one TCP segment per encrypted TDS package.
  1159. // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
  1160. config.DynamicRecordSizingDisabled = true
  1161. outbuf.transport = conn
  1162. toconn.buf = outbuf
  1163. tlsConn := tls.Client(toconn, &config)
  1164. err = tlsConn.Handshake()
  1165. toconn.buf = nil
  1166. outbuf.transport = tlsConn
  1167. if err != nil {
  1168. return nil, fmt.Errorf("TLS Handshake failed: %v", err)
  1169. }
  1170. if encrypt == encryptOff {
  1171. outbuf.afterFirst = func() {
  1172. outbuf.transport = toconn
  1173. }
  1174. }
  1175. }
  1176. login := login{
  1177. TDSVersion: verTDS74,
  1178. PacketSize: uint32(outbuf.PackageSize()),
  1179. Database: p.database,
  1180. OptionFlags2: fODBC, // to get unlimited TEXTSIZE
  1181. HostName: p.workstation,
  1182. ServerName: p.host,
  1183. AppName: p.appname,
  1184. TypeFlags: p.typeFlags,
  1185. }
  1186. auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
  1187. if auth_ok {
  1188. login.SSPI, err = auth.InitialBytes()
  1189. if err != nil {
  1190. return nil, err
  1191. }
  1192. login.OptionFlags2 |= fIntSecurity
  1193. defer auth.Free()
  1194. } else {
  1195. login.UserName = p.user
  1196. login.Password = p.password
  1197. }
  1198. err = sendLogin(outbuf, login)
  1199. if err != nil {
  1200. return nil, err
  1201. }
  1202. // processing login response
  1203. var sspi_msg []byte
  1204. continue_login:
  1205. tokchan := make(chan tokenStruct, 5)
  1206. go processResponse(context.Background(), &sess, tokchan, nil)
  1207. success := false
  1208. for tok := range tokchan {
  1209. switch token := tok.(type) {
  1210. case sspiMsg:
  1211. sspi_msg, err = auth.NextBytes(token)
  1212. if err != nil {
  1213. return nil, err
  1214. }
  1215. case loginAckStruct:
  1216. success = true
  1217. sess.loginAck = token
  1218. case error:
  1219. return nil, fmt.Errorf("Login error: %s", token.Error())
  1220. case doneStruct:
  1221. if token.isError() {
  1222. return nil, fmt.Errorf("Login error: %s", token.getError())
  1223. }
  1224. }
  1225. }
  1226. if sspi_msg != nil {
  1227. outbuf.BeginPacket(packSSPIMessage, false)
  1228. _, err = outbuf.Write(sspi_msg)
  1229. if err != nil {
  1230. return nil, err
  1231. }
  1232. err = outbuf.FinishPacket()
  1233. if err != nil {
  1234. return nil, err
  1235. }
  1236. sspi_msg = nil
  1237. goto continue_login
  1238. }
  1239. if !success {
  1240. return nil, fmt.Errorf("Login failed")
  1241. }
  1242. if sess.routedServer != "" {
  1243. toconn.Close()
  1244. p.host = sess.routedServer
  1245. p.port = uint64(sess.routedPort)
  1246. goto initiate_connection
  1247. }
  1248. return &sess, nil
  1249. }