diff --git a/webclient.go b/webclient.go index df75622..f76cc9a 100644 --- a/webclient.go +++ b/webclient.go @@ -676,6 +676,7 @@ func startClient(conn *websocket.Conn) (err error) { defer func() { if isWSNormalError(err) { err = nil + c.close(nil) } else { m, e := errorToWSCloseMessage(err) if m != "" { @@ -685,13 +686,8 @@ func startClient(conn *websocket.Conn) (err error) { Value: m, }) } - select { - case c.writeCh <- closeMessage{e}: - case <-c.writerDone: - } + c.close(e) } - close(c.writeCh) - c.writeCh = nil }() c.writerDone = make(chan struct{}) @@ -1153,10 +1149,13 @@ func clientWriter(conn *websocket.Conn, ch <-chan interface{}, done chan<- struc return } case closeMessage: - err := conn.WriteMessage(websocket.CloseMessage, m.data) - if err != nil { - return + if m.data != nil { + conn.WriteMessage( + websocket.CloseMessage, + m.data, + ) } + return default: log.Printf("clientWiter: unexpected message %T", m) return @@ -1185,6 +1184,15 @@ func (c *webClient) write(m clientMessage) error { } } +func (c *webClient) close(data []byte) error { + select { + case c.writeCh <- closeMessage{data}: + return nil + case <-c.writerDone: + return ErrWriterDead + } +} + func (c *webClient) error(err error) error { switch e := err.(type) { case userError: