|
|
@@ -10,18 +10,27 @@ import (
|
|
|
"errors"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
- "math/rand"
|
|
|
"net"
|
|
|
"strconv"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
+ "unicode/utf8"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
+ // Frame header byte 0 bits from Section 5.2 of RFC 6455
|
|
|
+ finalBit = 1 << 7
|
|
|
+ rsv1Bit = 1 << 6
|
|
|
+ rsv2Bit = 1 << 5
|
|
|
+ rsv3Bit = 1 << 4
|
|
|
+
|
|
|
+ // Frame header byte 1 bits from Section 5.2 of RFC 6455
|
|
|
+ maskBit = 1 << 7
|
|
|
+
|
|
|
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
|
|
|
maxControlFramePayloadSize = 125
|
|
|
- finalBit = 1 << 7
|
|
|
- maskBit = 1 << 7
|
|
|
- writeWait = time.Second
|
|
|
+
|
|
|
+ writeWait = time.Second
|
|
|
|
|
|
defaultReadBufferSize = 4096
|
|
|
defaultWriteBufferSize = 4096
|
|
|
@@ -43,6 +52,8 @@ const (
|
|
|
CloseMessageTooBig = 1009
|
|
|
CloseMandatoryExtension = 1010
|
|
|
CloseInternalServerErr = 1011
|
|
|
+ CloseServiceRestart = 1012
|
|
|
+ CloseTryAgainLater = 1013
|
|
|
CloseTLSHandshake = 1015
|
|
|
)
|
|
|
|
|
|
@@ -184,51 +195,65 @@ func isData(frameType int) bool {
|
|
|
return frameType == TextMessage || frameType == BinaryMessage
|
|
|
}
|
|
|
|
|
|
-func maskBytes(key [4]byte, pos int, b []byte) int {
|
|
|
- for i := range b {
|
|
|
- b[i] ^= key[pos&3]
|
|
|
- pos++
|
|
|
- }
|
|
|
- return pos & 3
|
|
|
+var validReceivedCloseCodes = map[int]bool{
|
|
|
+ // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
|
|
|
+
|
|
|
+ CloseNormalClosure: true,
|
|
|
+ CloseGoingAway: true,
|
|
|
+ CloseProtocolError: true,
|
|
|
+ CloseUnsupportedData: true,
|
|
|
+ CloseNoStatusReceived: false,
|
|
|
+ CloseAbnormalClosure: false,
|
|
|
+ CloseInvalidFramePayloadData: true,
|
|
|
+ ClosePolicyViolation: true,
|
|
|
+ CloseMessageTooBig: true,
|
|
|
+ CloseMandatoryExtension: true,
|
|
|
+ CloseInternalServerErr: true,
|
|
|
+ CloseServiceRestart: true,
|
|
|
+ CloseTryAgainLater: true,
|
|
|
+ CloseTLSHandshake: false,
|
|
|
}
|
|
|
|
|
|
-func newMaskKey() [4]byte {
|
|
|
- n := rand.Uint32()
|
|
|
- return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
|
|
|
+func isValidReceivedCloseCode(code int) bool {
|
|
|
+ return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
|
|
|
}
|
|
|
|
|
|
-// Conn represents a WebSocket connection.
|
|
|
+// The Conn type represents a WebSocket connection.
|
|
|
type Conn struct {
|
|
|
conn net.Conn
|
|
|
isServer bool
|
|
|
subprotocol string
|
|
|
|
|
|
// Write fields
|
|
|
- mu chan bool // used as mutex to protect write to conn and closeSent
|
|
|
- closeSent bool // true if close message was sent
|
|
|
-
|
|
|
- // Message writer fields.
|
|
|
- writeErr error
|
|
|
- writeBuf []byte // frame is constructed in this buffer.
|
|
|
- writePos int // end of data in writeBuf.
|
|
|
- writeFrameType int // type of the current frame.
|
|
|
- writeSeq int // incremented to invalidate message writers.
|
|
|
- writeDeadline time.Time
|
|
|
- isWriting bool // for best-effort concurrent write detection
|
|
|
+ mu chan bool // used as mutex to protect write to conn
|
|
|
+ writeBuf []byte // frame is constructed in this buffer.
|
|
|
+ writeDeadline time.Time
|
|
|
+ writer io.WriteCloser // the current writer returned to the application
|
|
|
+ isWriting bool // for best-effort concurrent write detection
|
|
|
+
|
|
|
+ writeErrMu sync.Mutex
|
|
|
+ writeErr error
|
|
|
+
|
|
|
+ enableWriteCompression bool
|
|
|
+ newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
|
|
|
|
|
// Read fields
|
|
|
readErr error
|
|
|
br *bufio.Reader
|
|
|
readRemaining int64 // bytes remaining in current frame.
|
|
|
readFinal bool // true the current message has more frames.
|
|
|
- readSeq int // incremented to invalidate message readers.
|
|
|
readLength int64 // Message size.
|
|
|
readLimit int64 // Maximum message size.
|
|
|
readMaskPos int
|
|
|
readMaskKey [4]byte
|
|
|
handlePong func(string) error
|
|
|
handlePing func(string) error
|
|
|
+ handleClose func(int, string) error
|
|
|
readErrCount int
|
|
|
+ messageReader *messageReader // the current low-level reader
|
|
|
+
|
|
|
+ readDecompress bool // whether last read frame had RSV1 set
|
|
|
+ newDecompressionReader func(io.Reader) io.Reader
|
|
|
}
|
|
|
|
|
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
|
|
@@ -238,20 +263,23 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
|
|
|
if readBufferSize == 0 {
|
|
|
readBufferSize = defaultReadBufferSize
|
|
|
}
|
|
|
+ if readBufferSize < maxControlFramePayloadSize {
|
|
|
+ readBufferSize = maxControlFramePayloadSize
|
|
|
+ }
|
|
|
if writeBufferSize == 0 {
|
|
|
writeBufferSize = defaultWriteBufferSize
|
|
|
}
|
|
|
|
|
|
c := &Conn{
|
|
|
- isServer: isServer,
|
|
|
- br: bufio.NewReaderSize(conn, readBufferSize),
|
|
|
- conn: conn,
|
|
|
- mu: mu,
|
|
|
- readFinal: true,
|
|
|
- writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
|
|
- writeFrameType: noFrame,
|
|
|
- writePos: maxFrameHeaderSize,
|
|
|
- }
|
|
|
+ isServer: isServer,
|
|
|
+ br: bufio.NewReaderSize(conn, readBufferSize),
|
|
|
+ conn: conn,
|
|
|
+ mu: mu,
|
|
|
+ readFinal: true,
|
|
|
+ writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
|
|
+ enableWriteCompression: true,
|
|
|
+ }
|
|
|
+ c.SetCloseHandler(nil)
|
|
|
c.SetPingHandler(nil)
|
|
|
c.SetPongHandler(nil)
|
|
|
return c
|
|
|
@@ -279,29 +307,40 @@ func (c *Conn) RemoteAddr() net.Addr {
|
|
|
|
|
|
// Write methods
|
|
|
|
|
|
+func (c *Conn) writeFatal(err error) error {
|
|
|
+ err = hideTempErr(err)
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ if c.writeErr == nil {
|
|
|
+ c.writeErr = err
|
|
|
+ }
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
|
|
|
<-c.mu
|
|
|
defer func() { c.mu <- true }()
|
|
|
|
|
|
- if c.closeSent {
|
|
|
- return ErrCloseSent
|
|
|
- } else if frameType == CloseMessage {
|
|
|
- c.closeSent = true
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ err := c.writeErr
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
c.conn.SetWriteDeadline(deadline)
|
|
|
for _, buf := range bufs {
|
|
|
if len(buf) > 0 {
|
|
|
- n, err := c.conn.Write(buf)
|
|
|
- if n != len(buf) {
|
|
|
- // Close on partial write.
|
|
|
- c.conn.Close()
|
|
|
- }
|
|
|
+ _, err := c.conn.Write(buf)
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return c.writeFatal(err)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ if frameType == CloseMessage {
|
|
|
+ c.writeFatal(ErrCloseSent)
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -350,60 +389,108 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|
|
}
|
|
|
defer func() { c.mu <- true }()
|
|
|
|
|
|
- if c.closeSent {
|
|
|
- return ErrCloseSent
|
|
|
- } else if messageType == CloseMessage {
|
|
|
- c.closeSent = true
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ err := c.writeErr
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
c.conn.SetWriteDeadline(deadline)
|
|
|
- n, err := c.conn.Write(buf)
|
|
|
- if n != 0 && n != len(buf) {
|
|
|
- c.conn.Close()
|
|
|
+ _, err = c.conn.Write(buf)
|
|
|
+ if err != nil {
|
|
|
+ return c.writeFatal(err)
|
|
|
}
|
|
|
- return hideTempErr(err)
|
|
|
+ if messageType == CloseMessage {
|
|
|
+ c.writeFatal(ErrCloseSent)
|
|
|
+ }
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
-// NextWriter returns a writer for the next message to send. The writer's
|
|
|
-// Close method flushes the complete message to the network.
|
|
|
+func (c *Conn) prepWrite(messageType int) error {
|
|
|
+ // Close previous writer if not already closed by the application. It's
|
|
|
+ // probably better to return an error in this situation, but we cannot
|
|
|
+ // change this without breaking existing applications.
|
|
|
+ if c.writer != nil {
|
|
|
+ c.writer.Close()
|
|
|
+ c.writer = nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if !isControl(messageType) && !isData(messageType) {
|
|
|
+ return errBadWriteOpCode
|
|
|
+ }
|
|
|
+
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ err := c.writeErr
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+// NextWriter returns a writer for the next message to send. The writer's Close
|
|
|
+// method flushes the complete message to the network.
|
|
|
//
|
|
|
// There can be at most one open writer on a connection. NextWriter closes the
|
|
|
// previous writer if the application has not already done so.
|
|
|
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|
|
- if c.writeErr != nil {
|
|
|
- return nil, c.writeErr
|
|
|
+ if err := c.prepWrite(messageType); err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- if c.writeFrameType != noFrame {
|
|
|
- if err := c.flushFrame(true, nil); err != nil {
|
|
|
+ mw := &messageWriter{
|
|
|
+ c: c,
|
|
|
+ frameType: messageType,
|
|
|
+ pos: maxFrameHeaderSize,
|
|
|
+ }
|
|
|
+ c.writer = mw
|
|
|
+ if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
|
|
+ w, err := c.newCompressionWriter(c.writer)
|
|
|
+ if err != nil {
|
|
|
+ c.writer = nil
|
|
|
return nil, err
|
|
|
}
|
|
|
+ mw.compress = true
|
|
|
+ c.writer = w
|
|
|
}
|
|
|
+ return c.writer, nil
|
|
|
+}
|
|
|
|
|
|
- if !isControl(messageType) && !isData(messageType) {
|
|
|
- return nil, errBadWriteOpCode
|
|
|
- }
|
|
|
+type messageWriter struct {
|
|
|
+ c *Conn
|
|
|
+ compress bool // whether next call to flushFrame should set RSV1
|
|
|
+ pos int // end of data in writeBuf.
|
|
|
+ frameType int // type of the current frame.
|
|
|
+ err error
|
|
|
+}
|
|
|
|
|
|
- c.writeFrameType = messageType
|
|
|
- return messageWriter{c, c.writeSeq}, nil
|
|
|
+func (w *messageWriter) fatal(err error) error {
|
|
|
+ if w.err != nil {
|
|
|
+ w.err = err
|
|
|
+ w.c.writer = nil
|
|
|
+ }
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
- length := c.writePos - maxFrameHeaderSize + len(extra)
|
|
|
+// flushFrame writes buffered data and extra as a frame to the network. The
|
|
|
+// final argument indicates that this is the last frame in the message.
|
|
|
+func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|
|
+ c := w.c
|
|
|
+ length := w.pos - maxFrameHeaderSize + len(extra)
|
|
|
|
|
|
// Check for invalid control frames.
|
|
|
- if isControl(c.writeFrameType) &&
|
|
|
+ if isControl(w.frameType) &&
|
|
|
(!final || length > maxControlFramePayloadSize) {
|
|
|
- c.writeSeq++
|
|
|
- c.writeFrameType = noFrame
|
|
|
- c.writePos = maxFrameHeaderSize
|
|
|
- return errInvalidControlFrame
|
|
|
+ return w.fatal(errInvalidControlFrame)
|
|
|
}
|
|
|
|
|
|
- b0 := byte(c.writeFrameType)
|
|
|
+ b0 := byte(w.frameType)
|
|
|
if final {
|
|
|
b0 |= finalBit
|
|
|
}
|
|
|
+ if w.compress {
|
|
|
+ b0 |= rsv1Bit
|
|
|
+ }
|
|
|
+ w.compress = false
|
|
|
+
|
|
|
b1 := byte(0)
|
|
|
if !c.isServer {
|
|
|
b1 |= maskBit
|
|
|
@@ -435,10 +522,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
if !c.isServer {
|
|
|
key := newMaskKey()
|
|
|
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
|
|
|
- maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos])
|
|
|
+ maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
|
|
|
if len(extra) > 0 {
|
|
|
- c.writeErr = errors.New("websocket: internal error, extra used in client mode")
|
|
|
- return c.writeErr
|
|
|
+ return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -451,46 +537,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
}
|
|
|
c.isWriting = true
|
|
|
|
|
|
- c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
|
|
|
+ err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
|
|
|
|
|
|
if !c.isWriting {
|
|
|
panic("concurrent write to websocket connection")
|
|
|
}
|
|
|
c.isWriting = false
|
|
|
|
|
|
- // Setup for next frame.
|
|
|
- c.writePos = maxFrameHeaderSize
|
|
|
- c.writeFrameType = continuationFrame
|
|
|
- if final {
|
|
|
- c.writeSeq++
|
|
|
- c.writeFrameType = noFrame
|
|
|
+ if err != nil {
|
|
|
+ return w.fatal(err)
|
|
|
}
|
|
|
- return c.writeErr
|
|
|
-}
|
|
|
|
|
|
-type messageWriter struct {
|
|
|
- c *Conn
|
|
|
- seq int
|
|
|
-}
|
|
|
-
|
|
|
-func (w messageWriter) err() error {
|
|
|
- c := w.c
|
|
|
- if c.writeSeq != w.seq {
|
|
|
- return errWriteClosed
|
|
|
- }
|
|
|
- if c.writeErr != nil {
|
|
|
- return c.writeErr
|
|
|
+ if final {
|
|
|
+ c.writer = nil
|
|
|
+ return nil
|
|
|
}
|
|
|
+
|
|
|
+ // Setup for next frame.
|
|
|
+ w.pos = maxFrameHeaderSize
|
|
|
+ w.frameType = continuationFrame
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (w messageWriter) ncopy(max int) (int, error) {
|
|
|
- n := len(w.c.writeBuf) - w.c.writePos
|
|
|
+func (w *messageWriter) ncopy(max int) (int, error) {
|
|
|
+ n := len(w.c.writeBuf) - w.pos
|
|
|
if n <= 0 {
|
|
|
- if err := w.c.flushFrame(false, nil); err != nil {
|
|
|
+ if err := w.flushFrame(false, nil); err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- n = len(w.c.writeBuf) - w.c.writePos
|
|
|
+ n = len(w.c.writeBuf) - w.pos
|
|
|
}
|
|
|
if n > max {
|
|
|
n = max
|
|
|
@@ -498,14 +573,14 @@ func (w messageWriter) ncopy(max int) (int, error) {
|
|
|
return n, nil
|
|
|
}
|
|
|
|
|
|
-func (w messageWriter) write(final bool, p []byte) (int, error) {
|
|
|
- if err := w.err(); err != nil {
|
|
|
- return 0, err
|
|
|
+func (w *messageWriter) Write(p []byte) (int, error) {
|
|
|
+ if w.err != nil {
|
|
|
+ return 0, w.err
|
|
|
}
|
|
|
|
|
|
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
|
|
|
// Don't buffer large messages.
|
|
|
- err := w.c.flushFrame(final, p)
|
|
|
+ err := w.flushFrame(false, p)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
@@ -518,20 +593,16 @@ func (w messageWriter) write(final bool, p []byte) (int, error) {
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- copy(w.c.writeBuf[w.c.writePos:], p[:n])
|
|
|
- w.c.writePos += n
|
|
|
+ copy(w.c.writeBuf[w.pos:], p[:n])
|
|
|
+ w.pos += n
|
|
|
p = p[n:]
|
|
|
}
|
|
|
return nn, nil
|
|
|
}
|
|
|
|
|
|
-func (w messageWriter) Write(p []byte) (int, error) {
|
|
|
- return w.write(false, p)
|
|
|
-}
|
|
|
-
|
|
|
-func (w messageWriter) WriteString(p string) (int, error) {
|
|
|
- if err := w.err(); err != nil {
|
|
|
- return 0, err
|
|
|
+func (w *messageWriter) WriteString(p string) (int, error) {
|
|
|
+ if w.err != nil {
|
|
|
+ return 0, w.err
|
|
|
}
|
|
|
|
|
|
nn := len(p)
|
|
|
@@ -540,27 +611,27 @@ func (w messageWriter) WriteString(p string) (int, error) {
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- copy(w.c.writeBuf[w.c.writePos:], p[:n])
|
|
|
- w.c.writePos += n
|
|
|
+ copy(w.c.writeBuf[w.pos:], p[:n])
|
|
|
+ w.pos += n
|
|
|
p = p[n:]
|
|
|
}
|
|
|
return nn, nil
|
|
|
}
|
|
|
|
|
|
-func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
|
|
- if err := w.err(); err != nil {
|
|
|
- return 0, err
|
|
|
+func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
|
|
+ if w.err != nil {
|
|
|
+ return 0, w.err
|
|
|
}
|
|
|
for {
|
|
|
- if w.c.writePos == len(w.c.writeBuf) {
|
|
|
- err = w.c.flushFrame(false, nil)
|
|
|
+ if w.pos == len(w.c.writeBuf) {
|
|
|
+ err = w.flushFrame(false, nil)
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
}
|
|
|
var n int
|
|
|
- n, err = r.Read(w.c.writeBuf[w.c.writePos:])
|
|
|
- w.c.writePos += n
|
|
|
+ n, err = r.Read(w.c.writeBuf[w.pos:])
|
|
|
+ w.pos += n
|
|
|
nn += int64(n)
|
|
|
if err != nil {
|
|
|
if err == io.EOF {
|
|
|
@@ -572,30 +643,43 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
|
|
return nn, err
|
|
|
}
|
|
|
|
|
|
-func (w messageWriter) Close() error {
|
|
|
- if err := w.err(); err != nil {
|
|
|
+func (w *messageWriter) Close() error {
|
|
|
+ if w.err != nil {
|
|
|
+ return w.err
|
|
|
+ }
|
|
|
+ if err := w.flushFrame(true, nil); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- return w.c.flushFrame(true, nil)
|
|
|
+ w.err = errWriteClosed
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
// WriteMessage is a helper method for getting a writer using NextWriter,
|
|
|
// writing the message and closing the writer.
|
|
|
func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
|
|
- wr, err := c.NextWriter(messageType)
|
|
|
+
|
|
|
+ if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
|
|
|
+
|
|
|
+ // Fast path with no allocations and single frame.
|
|
|
+
|
|
|
+ if err := c.prepWrite(messageType); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
|
|
|
+ n := copy(c.writeBuf[mw.pos:], data)
|
|
|
+ mw.pos += n
|
|
|
+ data = data[n:]
|
|
|
+ return mw.flushFrame(true, data)
|
|
|
+ }
|
|
|
+
|
|
|
+ w, err := c.NextWriter(messageType)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- w := wr.(messageWriter)
|
|
|
- if _, err := w.write(true, data); err != nil {
|
|
|
+ if _, err = w.Write(data); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if c.writeSeq == w.seq {
|
|
|
- if err := c.flushFrame(true, nil); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- }
|
|
|
- return nil
|
|
|
+ return w.Close()
|
|
|
}
|
|
|
|
|
|
// SetWriteDeadline sets the write deadline on the underlying network
|
|
|
@@ -609,22 +693,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|
|
|
|
|
// Read methods
|
|
|
|
|
|
-// readFull is like io.ReadFull except that io.EOF is never returned.
|
|
|
-func (c *Conn) readFull(p []byte) (err error) {
|
|
|
- var n int
|
|
|
- for n < len(p) && err == nil {
|
|
|
- var nn int
|
|
|
- nn, err = c.br.Read(p[n:])
|
|
|
- n += nn
|
|
|
- }
|
|
|
- if n == len(p) {
|
|
|
- err = nil
|
|
|
- } else if err == io.EOF {
|
|
|
- err = errUnexpectedEOF
|
|
|
- }
|
|
|
- return
|
|
|
-}
|
|
|
-
|
|
|
func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
// 1. Skip remainder of previous frame.
|
|
|
@@ -637,19 +705,24 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
// 2. Read and parse first two bytes of frame header.
|
|
|
|
|
|
- var b [8]byte
|
|
|
- if err := c.readFull(b[:2]); err != nil {
|
|
|
+ p, err := c.read(2)
|
|
|
+ if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
|
|
|
- final := b[0]&finalBit != 0
|
|
|
- frameType := int(b[0] & 0xf)
|
|
|
- reserved := int((b[0] >> 4) & 0x7)
|
|
|
- mask := b[1]&maskBit != 0
|
|
|
- c.readRemaining = int64(b[1] & 0x7f)
|
|
|
+ final := p[0]&finalBit != 0
|
|
|
+ frameType := int(p[0] & 0xf)
|
|
|
+ mask := p[1]&maskBit != 0
|
|
|
+ c.readRemaining = int64(p[1] & 0x7f)
|
|
|
+
|
|
|
+ c.readDecompress = false
|
|
|
+ if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
|
|
+ c.readDecompress = true
|
|
|
+ p[0] &^= rsv1Bit
|
|
|
+ }
|
|
|
|
|
|
- if reserved != 0 {
|
|
|
- return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
|
|
+ if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
|
|
|
+ return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
|
|
|
}
|
|
|
|
|
|
switch frameType {
|
|
|
@@ -678,15 +751,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
switch c.readRemaining {
|
|
|
case 126:
|
|
|
- if err := c.readFull(b[:2]); err != nil {
|
|
|
+ p, err := c.read(2)
|
|
|
+ if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
- c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
|
|
|
+ c.readRemaining = int64(binary.BigEndian.Uint16(p))
|
|
|
case 127:
|
|
|
- if err := c.readFull(b[:8]); err != nil {
|
|
|
+ p, err := c.read(8)
|
|
|
+ if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
- c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
|
|
|
+ c.readRemaining = int64(binary.BigEndian.Uint64(p))
|
|
|
}
|
|
|
|
|
|
// 4. Handle frame masking.
|
|
|
@@ -697,9 +772,11 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
if mask {
|
|
|
c.readMaskPos = 0
|
|
|
- if err := c.readFull(c.readMaskKey[:]); err != nil {
|
|
|
+ p, err := c.read(len(c.readMaskKey))
|
|
|
+ if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
+ copy(c.readMaskKey[:], p)
|
|
|
}
|
|
|
|
|
|
// 5. For text and binary messages, enforce read limit and return.
|
|
|
@@ -719,9 +796,9 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
var payload []byte
|
|
|
if c.readRemaining > 0 {
|
|
|
- payload = make([]byte, c.readRemaining)
|
|
|
+ payload, err = c.read(int(c.readRemaining))
|
|
|
c.readRemaining = 0
|
|
|
- if err := c.readFull(payload); err != nil {
|
|
|
+ if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
if c.isServer {
|
|
|
@@ -741,15 +818,21 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
case CloseMessage:
|
|
|
- echoMessage := []byte{}
|
|
|
closeCode := CloseNoStatusReceived
|
|
|
closeText := ""
|
|
|
if len(payload) >= 2 {
|
|
|
- echoMessage = payload[:2]
|
|
|
closeCode = int(binary.BigEndian.Uint16(payload))
|
|
|
+ if !isValidReceivedCloseCode(closeCode) {
|
|
|
+ return noFrame, c.handleProtocolError("invalid close code")
|
|
|
+ }
|
|
|
closeText = string(payload[2:])
|
|
|
+ if !utf8.ValidString(closeText) {
|
|
|
+ return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if err := c.handleClose(closeCode, closeText); err != nil {
|
|
|
+ return noFrame, err
|
|
|
}
|
|
|
- c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait))
|
|
|
return noFrame, &CloseError{Code: closeCode, Text: closeText}
|
|
|
}
|
|
|
|
|
|
@@ -773,7 +856,7 @@ func (c *Conn) handleProtocolError(message string) error {
|
|
|
// this method return the same error.
|
|
|
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|
|
|
|
|
- c.readSeq++
|
|
|
+ c.messageReader = nil
|
|
|
c.readLength = 0
|
|
|
|
|
|
for c.readErr == nil {
|
|
|
@@ -783,7 +866,12 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|
|
break
|
|
|
}
|
|
|
if frameType == TextMessage || frameType == BinaryMessage {
|
|
|
- return frameType, messageReader{c, c.readSeq}, nil
|
|
|
+ c.messageReader = &messageReader{c}
|
|
|
+ var r io.Reader = c.messageReader
|
|
|
+ if c.readDecompress {
|
|
|
+ r = c.newDecompressionReader(r)
|
|
|
+ }
|
|
|
+ return frameType, r, nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -798,48 +886,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|
|
return noFrame, nil, c.readErr
|
|
|
}
|
|
|
|
|
|
-type messageReader struct {
|
|
|
- c *Conn
|
|
|
- seq int
|
|
|
-}
|
|
|
-
|
|
|
-func (r messageReader) Read(b []byte) (int, error) {
|
|
|
+type messageReader struct{ c *Conn }
|
|
|
|
|
|
- if r.seq != r.c.readSeq {
|
|
|
+func (r *messageReader) Read(b []byte) (int, error) {
|
|
|
+ c := r.c
|
|
|
+ if c.messageReader != r {
|
|
|
return 0, io.EOF
|
|
|
}
|
|
|
|
|
|
- for r.c.readErr == nil {
|
|
|
+ for c.readErr == nil {
|
|
|
|
|
|
- if r.c.readRemaining > 0 {
|
|
|
- if int64(len(b)) > r.c.readRemaining {
|
|
|
- b = b[:r.c.readRemaining]
|
|
|
+ if c.readRemaining > 0 {
|
|
|
+ if int64(len(b)) > c.readRemaining {
|
|
|
+ b = b[:c.readRemaining]
|
|
|
+ }
|
|
|
+ n, err := c.br.Read(b)
|
|
|
+ c.readErr = hideTempErr(err)
|
|
|
+ if c.isServer {
|
|
|
+ c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
|
|
}
|
|
|
- n, err := r.c.br.Read(b)
|
|
|
- r.c.readErr = hideTempErr(err)
|
|
|
- if r.c.isServer {
|
|
|
- r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
|
|
|
+ c.readRemaining -= int64(n)
|
|
|
+ if c.readRemaining > 0 && c.readErr == io.EOF {
|
|
|
+ c.readErr = errUnexpectedEOF
|
|
|
}
|
|
|
- r.c.readRemaining -= int64(n)
|
|
|
- return n, r.c.readErr
|
|
|
+ return n, c.readErr
|
|
|
}
|
|
|
|
|
|
- if r.c.readFinal {
|
|
|
- r.c.readSeq++
|
|
|
+ if c.readFinal {
|
|
|
+ c.messageReader = nil
|
|
|
return 0, io.EOF
|
|
|
}
|
|
|
|
|
|
- frameType, err := r.c.advanceFrame()
|
|
|
+ frameType, err := c.advanceFrame()
|
|
|
switch {
|
|
|
case err != nil:
|
|
|
- r.c.readErr = hideTempErr(err)
|
|
|
+ c.readErr = hideTempErr(err)
|
|
|
case frameType == TextMessage || frameType == BinaryMessage:
|
|
|
- r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
|
|
+ c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- err := r.c.readErr
|
|
|
- if err == io.EOF && r.seq == r.c.readSeq {
|
|
|
+ err := c.readErr
|
|
|
+ if err == io.EOF && c.messageReader == r {
|
|
|
err = errUnexpectedEOF
|
|
|
}
|
|
|
return 0, err
|
|
|
@@ -872,6 +960,34 @@ func (c *Conn) SetReadLimit(limit int64) {
|
|
|
c.readLimit = limit
|
|
|
}
|
|
|
|
|
|
+// CloseHandler returns the current close handler
|
|
|
+func (c *Conn) CloseHandler() func(code int, text string) error {
|
|
|
+ return c.handleClose
|
|
|
+}
|
|
|
+
|
|
|
+// SetCloseHandler sets the handler for close messages received from the peer.
|
|
|
+// The code argument to h is the received close code or CloseNoStatusReceived
|
|
|
+// if the close message is empty. The default close handler sends a close frame
|
|
|
+// back to the peer.
|
|
|
+func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
|
|
|
+ if h == nil {
|
|
|
+ h = func(code int, text string) error {
|
|
|
+ message := []byte{}
|
|
|
+ if code != CloseNoStatusReceived {
|
|
|
+ message = FormatCloseMessage(code, "")
|
|
|
+ }
|
|
|
+ c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ c.handleClose = h
|
|
|
+}
|
|
|
+
|
|
|
+// PingHandler returns the current ping handler
|
|
|
+func (c *Conn) PingHandler() func(appData string) error {
|
|
|
+ return c.handlePing
|
|
|
+}
|
|
|
+
|
|
|
// SetPingHandler sets the handler for ping messages received from the peer.
|
|
|
// The appData argument to h is the PING frame application data. The default
|
|
|
// ping handler sends a pong to the peer.
|
|
|
@@ -890,6 +1006,11 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
|
|
|
c.handlePing = h
|
|
|
}
|
|
|
|
|
|
+// PongHandler returns the current pong handler
|
|
|
+func (c *Conn) PongHandler() func(appData string) error {
|
|
|
+ return c.handlePong
|
|
|
+}
|
|
|
+
|
|
|
// SetPongHandler sets the handler for pong messages received from the peer.
|
|
|
// The appData argument to h is the PONG frame application data. The default
|
|
|
// pong handler does nothing.
|
|
|
@@ -906,6 +1027,13 @@ func (c *Conn) UnderlyingConn() net.Conn {
|
|
|
return c.conn
|
|
|
}
|
|
|
|
|
|
+// EnableWriteCompression enables and disables write compression of
|
|
|
+// subsequent text and binary messages. This function is a noop if
|
|
|
+// compression was not negotiated with the peer.
|
|
|
+func (c *Conn) EnableWriteCompression(enable bool) {
|
|
|
+ c.enableWriteCompression = enable
|
|
|
+}
|
|
|
+
|
|
|
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
|
|
|
func FormatCloseMessage(closeCode int, text string) []byte {
|
|
|
buf := make([]byte, 2+len(text))
|