1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-12-22 15:25:48 +01:00
galene/rtpconn/webclient.go
Juliusz Chroboczek ef0201c94d Replace RTCPFeedback in down tracks.
We used to copy the RTCPFeeback field from the up track.  It is
more correct to regenerate it with the exact feedback types
that we expect.
2024-11-30 17:55:24 +01:00

2256 lines
47 KiB
Go

package rtpconn
import (
crand "crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"os"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/pion/webrtc/v3"
"github.com/jech/galene/conn"
"github.com/jech/galene/diskwriter"
"github.com/jech/galene/estimator"
"github.com/jech/galene/group"
"github.com/jech/galene/ice"
"github.com/jech/galene/token"
"github.com/jech/galene/unbounded"
)
func errorToWSCloseMessage(id string, err error) (*clientMessage, []byte) {
var code int
var m *clientMessage
var text string
switch e := err.(type) {
case *websocket.CloseError:
code = websocket.CloseNormalClosure
case group.ProtocolError:
code = websocket.CloseProtocolError
m = &clientMessage{
Type: "usermessage",
Kind: "error",
Dest: id,
Privileged: true,
Value: e.Error(),
}
text = e.Error()
case group.UserError, group.KickError:
code = websocket.CloseNormalClosure
m = errorMessage(id, err)
text = e.Error()
default:
code = websocket.CloseInternalServerErr
}
return m, websocket.FormatCloseMessage(code, text)
}
func isWSNormalError(err error) bool {
return websocket.IsCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway)
}
type webClient struct {
group *group.Group
addr net.Addr
id string
username string
permissions []string
data map[string]interface{}
requested map[string][]string
done chan struct{}
writeCh chan interface{}
writerDone chan struct{}
actions *unbounded.Channel[any]
mu sync.Mutex
down map[string]*rtpDownConnection
up map[string]*rtpUpConnection
}
func (c *webClient) Group() *group.Group {
return c.group
}
func (c *webClient) Addr() net.Addr {
return c.addr
}
func (c *webClient) Id() string {
return c.id
}
func (c *webClient) Username() string {
return c.username
}
func (c *webClient) SetUsername(username string) {
c.username = username
}
func (c *webClient) Permissions() []string {
return c.permissions
}
func (c *webClient) Data() map[string]interface{} {
return c.data
}
func (c *webClient) SetPermissions(perms []string) {
c.permissions = perms
}
func (c *webClient) PushClient(group, kind, id string, username string, perms []string, data map[string]interface{}) error {
c.action(pushClientAction{
group, kind, id, username, perms, data,
})
return nil
}
type clientMessage struct {
Type string `json:"type"`
Version []string `json:"version,omitempty"`
Kind string `json:"kind,omitempty"`
Error string `json:"error,omitempty"`
Id string `json:"id,omitempty"`
Replace string `json:"replace,omitempty"`
Source string `json:"source,omitempty"`
Dest string `json:"dest,omitempty"`
Username *string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Token string `json:"token,omitempty"`
Privileged bool `json:"privileged,omitempty"`
Permissions []string `json:"permissions,omitempty"`
Status *group.Status `json:"status,omitempty"`
Data map[string]interface{} `json:"data,omitempty"`
Group string `json:"group,omitempty"`
Value interface{} `json:"value,omitempty"`
NoEcho bool `json:"noecho,omitempty"`
Time string `json:"time,omitempty"`
SDP string `json:"sdp,omitempty"`
Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"`
Label string `json:"label,omitempty"`
Request interface{} `json:"request,omitempty"`
RTCConfiguration *webrtc.Configuration `json:"rtcConfiguration,omitempty"`
}
type closeMessage struct {
data []byte
}
func getUpConn(c *webClient, id string) *rtpUpConnection {
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
return nil
}
return c.up[id]
}
func getUpConns(c *webClient) []*rtpUpConnection {
c.mu.Lock()
defer c.mu.Unlock()
up := make([]*rtpUpConnection, 0, len(c.up))
for _, u := range c.up {
up = append(up, u)
}
return up
}
func addUpConn(c *webClient, id, label string, offer string) (*rtpUpConnection, bool, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
c.up = make(map[string]*rtpUpConnection)
}
if c.down != nil && c.down[id] != nil {
return nil, false, errors.New("adding duplicate connection")
}
old := c.up[id]
if old != nil {
return old, false, nil
}
conn, err := newUpConn(c, id, label, offer)
if err != nil {
return nil, false, err
}
c.up[id] = conn
conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, id, candidate)
})
conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state == webrtc.ICEConnectionStateFailed {
c.action(connectionFailedAction{id: id})
}
})
return conn, true, nil
}
var ErrUserMismatch = errors.New("user id mismatch")
// delUpConn deletes an up connection. If push is closed, the close is
// pushed to all corresponding down connections.
func delUpConn(c *webClient, id string, userId string, push bool) error {
c.mu.Lock()
if c.up == nil {
c.mu.Unlock()
return os.ErrNotExist
}
conn := c.up[id]
if conn == nil {
c.mu.Unlock()
return os.ErrNotExist
}
if userId != "" {
id, _ := conn.User()
if id != userId {
c.mu.Unlock()
return ErrUserMismatch
}
}
replace := conn.getReplace(false)
delete(c.up, id)
g := c.group
c.mu.Unlock()
conn.mu.Lock()
conn.closed = true
conn.mu.Unlock()
conn.pc.Close()
if push && g != nil {
for _, c := range g.GetClients(c) {
err := c.PushConn(g, id, nil, nil, replace)
if err != nil {
log.Printf("PushConn: %v", err)
}
}
}
return nil
}
func getDownConn(c *webClient, id string) *rtpDownConnection {
if c.down == nil {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
conn := c.down[id]
if conn == nil {
return nil
}
return conn
}
func getConn(c *webClient, id string) iceConnection {
up := getUpConn(c, id)
if up != nil {
return up
}
down := getDownConn(c, id)
if down != nil {
return down
}
return nil
}
func addDownConn(c *webClient, remote conn.Up) (*rtpDownConnection, bool, error) {
id := remote.Id()
c.mu.Lock()
defer c.mu.Unlock()
if c.up != nil && c.up[id] != nil {
return nil, false, errors.New("adding duplicate connection")
}
if c.down == nil {
c.down = make(map[string]*rtpDownConnection)
}
if down := c.down[id]; down != nil {
return down, false, nil
}
down, err := newDownConn(c, id, remote)
if err != nil {
return nil, false, err
}
down.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, down.id, candidate)
})
down.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state == webrtc.ICEConnectionStateFailed {
c.action(connectionFailedAction{id: down.id})
}
})
err = remote.AddLocal(down)
if err != nil {
down.pc.Close()
return nil, false, err
}
c.down[down.id] = down
go rtcpDownSender(down)
return down, true, nil
}
func delDownConn(c *webClient, id string) error {
conn := delDownConnHelper(c, id)
if conn != nil {
conn.pc.Close()
return nil
}
return os.ErrNotExist
}
func delDownConnHelper(c *webClient, id string) *rtpDownConnection {
c.mu.Lock()
defer c.mu.Unlock()
if c.down == nil {
return nil
}
conn := c.down[id]
if conn == nil {
return nil
}
conn.remote.DelLocal(conn)
for _, track := range conn.tracks {
// we only insert the track after we get an answer, so
// ignore errors here.
track.remote.DelLocal(track)
}
delete(c.down, id)
return conn
}
var errUnexpectedTrackType = errors.New("unexpected track type, this shouldn't happen")
func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack) error {
for _, t := range conn.tracks {
tt, ok := t.remote.(*rtpUpTrack)
if !ok {
return errUnexpectedTrackType
}
if tt == remoteTrack {
return os.ErrExist
}
}
id := remoteTrack.track.ID()
if id == "" {
log.Println("Got track with empty id")
id = remoteTrack.track.RID()
}
if id == "" {
id = remoteTrack.track.Kind().String()
}
msid := remoteTrack.track.StreamID()
if msid == "" || msid == "-" {
log.Println("Got track with empty msid")
msid = remoteTrack.conn.Label()
}
if msid == "" {
msid = "dummy"
}
// replace the RTCP feedback types with the ones we understand
remoteCodec := remoteTrack.Codec()
if strings.HasPrefix(strings.ToLower(remoteCodec.MimeType), "video/") {
remoteCodec.RTCPFeedback = group.VideoRTCPFeedback
} else {
remoteCodec.RTCPFeedback = group.AudioRTCPFeedback
}
local, err := webrtc.NewTrackLocalStaticRTP(
remoteCodec, id, msid,
)
if err != nil {
return err
}
transceiver, err := conn.pc.AddTransceiverFromTrack(local,
webrtc.RTPTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionSendonly,
},
)
if err != nil {
return err
}
codec := local.Codec()
ptype, err := group.CodecPayloadType(local.Codec())
if err != nil {
log.Printf("Couldn't determine ptype for codec %v: %v",
codec.MimeType, err)
} else {
err := transceiver.SetCodecPreferences(
[]webrtc.RTPCodecParameters{
{
RTPCodecCapability: codec,
PayloadType: ptype,
},
},
)
if err != nil {
log.Printf("Couldn't set ptype for codec %v: %v",
codec.MimeType, err)
}
}
parms := transceiver.Sender().GetParameters()
if len(parms.Encodings) != 1 {
return errors.New("got multiple encodings")
}
track := &rtpDownTrack{
track: local,
sender: transceiver.Sender(),
ssrc: parms.Encodings[0].SSRC,
conn: conn,
remote: remoteTrack,
maxBitrate: new(bitrate),
maxREMBBitrate: new(bitrate),
stats: new(receiverStats),
rate: estimator.New(time.Second),
atomics: &downTrackAtomics{},
}
conn.tracks = append(conn.tracks, track)
go rtcpDownListener(track)
return nil
}
func delDownTrackUnlocked(conn *rtpDownConnection, track *rtpDownTrack) error {
for i := range conn.tracks {
if conn.tracks[i] == track {
track.remote.DelLocal(track)
conn.tracks =
append(conn.tracks[:i], conn.tracks[i+1:]...)
return conn.pc.RemoveTrack(track.sender)
}
}
return os.ErrNotExist
}
func replaceTracks(conn *rtpDownConnection, remote []conn.UpTrack, limitSid bool) (bool, error) {
conn.mu.Lock()
defer conn.mu.Unlock()
var add []*rtpUpTrack
var del []*rtpDownTrack
outer:
for _, rtrack := range remote {
rt, ok := rtrack.(*rtpUpTrack)
if !ok {
return false, errUnexpectedTrackType
}
for _, track := range conn.tracks {
rt2, ok := track.remote.(*rtpUpTrack)
if !ok {
return false, errUnexpectedTrackType
}
if rt == rt2 {
continue outer
}
}
add = append(add, rt)
}
outer2:
for _, track := range conn.tracks {
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
return false, errUnexpectedTrackType
}
for _, rtrack := range remote {
rt2, ok := rtrack.(*rtpUpTrack)
if !ok {
return false, errUnexpectedTrackType
}
if rt == rt2 {
continue outer2
}
}
del = append(del, track)
}
defer func() {
for _, t := range conn.tracks {
layer := t.getLayerInfo()
layer.limitSid = limitSid
if limitSid {
layer.wantedSid = 0
}
t.setLayerInfo(layer)
}
}()
if len(del) == 0 && len(add) == 0 {
return false, nil
}
for _, t := range del {
err := delDownTrackUnlocked(conn, t)
if err != nil {
return false, err
}
}
for _, rt := range add {
err := addDownTrackUnlocked(conn, rt)
if err != nil {
return false, err
}
}
return true, nil
}
func negotiate(c *webClient, down *rtpDownConnection, restartIce bool, replace string) error {
if down.pc.SignalingState() == webrtc.SignalingStateHaveLocalOffer {
// avoid sending multiple offers back-to-back
if restartIce {
down.negotiationNeeded = negotiationRestartIce
} else if down.negotiationNeeded == negotiationUnneeded {
down.negotiationNeeded = negotiationNeeded
}
return nil
}
down.negotiationNeeded = negotiationUnneeded
options := webrtc.OfferOptions{ICERestart: restartIce}
offer, err := down.pc.CreateOffer(&options)
if err != nil {
return err
}
err = down.pc.SetLocalDescription(offer)
if err != nil {
return err
}
source, username := down.remote.User()
return c.write(clientMessage{
Type: "offer",
Id: down.id,
Label: down.remote.Label(),
Replace: replace,
Source: source,
Username: &username,
SDP: down.pc.LocalDescription().SDP,
})
}
func sendICE(c *webClient, id string, candidate *webrtc.ICECandidate) error {
if candidate == nil {
return nil
}
cand := candidate.ToJSON()
return c.write(clientMessage{
Type: "ice",
Id: id,
Candidate: &cand,
})
}
func gotOffer(c *webClient, id, label string, sdp string, replace string) error {
up, _, err := addUpConn(c, id, label, sdp)
if err != nil {
return err
}
if replace != "" {
up.replace = replace
delUpConn(c, replace, c.Id(), false)
}
err = up.pc.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeOffer,
SDP: sdp,
})
if err != nil {
return err
}
answer, err := up.pc.CreateAnswer(nil)
if err != nil {
return err
}
err = up.pc.SetLocalDescription(answer)
if err != nil {
return err
}
err = up.flushICECandidates()
if err != nil {
log.Printf("ICE: %v", err)
}
return c.write(clientMessage{
Type: "answer",
Id: id,
SDP: up.pc.LocalDescription().SDP,
})
}
var ErrUnknownId = errors.New("unknown id")
func gotAnswer(c *webClient, id string, sdp string) error {
down := getDownConn(c, id)
if down == nil {
return ErrUnknownId
}
err := down.pc.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeAnswer,
SDP: sdp,
})
if err != nil {
return err
}
err = down.flushICECandidates()
if err != nil {
log.Printf("ICE: %v", err)
}
add := func() {
down.pc.OnConnectionStateChange(nil)
for _, t := range down.tracks {
err := t.remote.AddLocal(t)
if err != nil && err != os.ErrClosed {
log.Printf("Add track: %v", err)
}
}
}
down.pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateConnected {
add()
}
})
if down.pc.ConnectionState() == webrtc.PeerConnectionStateConnected {
add()
}
return nil
}
func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error {
conn := getConn(c, id)
if conn == nil {
return errors.New("unknown id in ICE")
}
return conn.addICECandidate(candidate)
}
var errBadType = errors.New("bad type")
func toStringArray(r interface{}) ([]string, error) {
if r == nil {
return nil, nil
}
rr, ok := r.([]interface{})
if !ok {
return nil, errBadType
}
if rr == nil {
return nil, nil
}
rrr := make([]string, len(rr))
for i, s := range rr {
rrr[i], ok = s.(string)
if !ok {
return nil, errBadType
}
}
return rrr, nil
}
func parseRequested(r interface{}) (map[string][]string, error) {
if r == nil {
return nil, nil
}
rr, ok := r.(map[string]interface{})
if !ok {
return nil, errBadType
}
if rr == nil {
return nil, nil
}
rrr := make(map[string][]string)
for k, v := range rr {
vv, err := toStringArray(v)
if err != nil {
return nil, err
}
rrr[k] = vv
}
return rrr, nil
}
func (c *webClient) setRequested(requested map[string][]string) error {
if c.group == nil {
return errors.New("attempted to request with no group joined")
}
c.requested = requested
requestConns(c, c.group, "")
return nil
}
func (c *webClient) setRequestedStream(down *rtpDownConnection, requested []string) error {
var remoteClient group.Client
remote, ok := down.remote.(*rtpUpConnection)
if ok {
remoteClient = remote.client
}
down.requested = requested
return remoteClient.RequestConns(c, c.group, remote.id)
}
func (c *webClient) RequestConns(target group.Client, g *group.Group, id string) error {
c.action(requestConnsAction{g, target, id})
return nil
}
func requestConns(target group.Client, g *group.Group, id string) {
clients := g.GetClients(target)
for _, c := range clients {
c.RequestConns(target, g, id)
}
}
func requestedTracks(c *webClient, requested []string, tracks []conn.UpTrack) ([]conn.UpTrack, bool) {
if len(requested) == 0 {
return nil, false
}
var audio, video, videoLow bool
for _, s := range requested {
switch s {
case "audio":
audio = true
case "video":
video = true
case "video-low":
videoLow = true
default:
log.Printf("client requested unknown value %v", s)
}
}
find := func(kind webrtc.RTPCodecType, last bool) (conn.UpTrack, int) {
var track conn.UpTrack
count := 0
for _, t := range tracks {
if t.Kind() != kind {
continue
}
track = t
count++
if !last {
break
}
}
return track, count
}
var ts []conn.UpTrack
limitSid := false
if audio {
t, _ := find(webrtc.RTPCodecTypeAudio, false)
if t != nil {
ts = append(ts, t)
}
}
if video {
t, _ := find(webrtc.RTPCodecTypeVideo, false)
if t != nil {
ts = append(ts, t)
}
} else if videoLow {
t, count := find(webrtc.RTPCodecTypeVideo, true)
if t != nil {
ts = append(ts, t)
}
if count < 2 {
limitSid = true
}
}
return ts, limitSid
}
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
c.action(pushConnAction{g, id, up, tracks, replace})
return nil
}
func readMessage(conn *websocket.Conn, m *clientMessage) error {
err := conn.SetReadDeadline(time.Now().Add(15 * time.Second))
if err != nil {
return err
}
defer conn.SetReadDeadline(time.Time{})
return conn.ReadJSON(&m)
}
const protocolVersion = "2"
func StartClient(conn *websocket.Conn, addr net.Addr) (err error) {
var m clientMessage
err = readMessage(conn, &m)
if err != nil {
conn.Close()
return
}
if m.Type != "handshake" {
conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(
websocket.CloseProtocolError,
"you must handshake first",
),
)
conn.Close()
err = group.ProtocolError("client didn't handshake")
return
}
versionError := true
if m.Version != nil {
for _, v := range m.Version {
if v == protocolVersion {
versionError = false
}
}
}
c := &webClient{
addr: addr,
id: m.Id,
actions: unbounded.New[any](),
done: make(chan struct{}),
}
defer close(c.done)
c.writeCh = make(chan interface{}, 100)
c.writerDone = make(chan struct{})
go clientWriter(conn, c.writeCh, c.writerDone)
defer func() {
m, e := errorToWSCloseMessage(c.id, err)
if isWSNormalError(err) {
err = nil
} else if _, ok := err.(group.KickError); ok {
err = nil
}
if m != nil {
c.write(*m)
}
c.close(e)
}()
return clientLoop(c, conn, versionError)
}
type pushConnAction struct {
group *group.Group
id string
conn conn.Up
tracks []conn.UpTrack
replace string
}
type requestConnsAction struct {
group *group.Group
target group.Client
id string
}
type connectionFailedAction struct {
id string
}
type pushClientAction struct {
group string
kind string
id string
username string
permissions []string
data map[string]interface{}
}
type permissionsChangedAction struct{}
type joinedAction struct {
group string
kind string
}
type kickAction struct {
id string
username *string
message string
}
var errEmptyId = group.ProtocolError("empty id")
func member(v string, l []string) bool {
for _, w := range l {
if v == w {
return true
}
}
return false
}
func remove(v string, l []string) []string {
for i, w := range l {
if v == w {
l = append(l[:i], l[i+1:]...)
return l
}
}
return l
}
func addnew(v string, l []string) []string {
if member(v, l) {
return l
}
l = append(l, v)
return l
}
func clientLoop(c *webClient, ws *websocket.Conn, versionError bool) error {
read := make(chan interface{}, 1)
go clientReader(ws, read, c.done)
defer leaveGroup(c)
readTime := time.Now()
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
err := c.write(clientMessage{
Type: "handshake",
Version: []string{protocolVersion},
})
if err != nil {
return err
}
if versionError {
c.write(clientMessage{
Type: "usermessage",
Kind: "warning",
Dest: c.id,
Privileged: true,
Value: "This client is using an unknown protocol version.\n" +
"Perhaps it needs upgrading?\n" +
"Trying to continue, things may break.",
})
}
for {
select {
case m, ok := <-read:
if !ok {
return errors.New("reader died")
}
switch m := m.(type) {
case clientMessage:
readTime = time.Now()
err := handleClientMessage(c, m)
if err != nil {
return err
}
case error:
return m
}
case <-c.actions.Ch:
actions := c.actions.Get()
for _, a := range actions {
err := handleAction(c, a)
if err != nil {
return err
}
}
case <-ticker.C:
if time.Since(readTime) > 45*time.Second {
return errors.New("client is dead")
}
// Some reverse proxies timeout connexions at 60
// seconds, make sure we generate some activity
if time.Since(readTime) > 20*time.Second {
err := c.write(clientMessage{
Type: "ping",
})
if err != nil {
return err
}
}
}
}
}
func pushDownConn(c *webClient, id string, up conn.Up, tracks []conn.UpTrack, replace string) error {
var requested []conn.UpTrack
limitSid := false
if up != nil {
var old *rtpDownConnection
if replace != "" {
old = getDownConn(c, replace)
} else {
old = getDownConn(c, up.Id())
}
var req []string
if old != nil {
req = old.requested
}
if req == nil {
var ok bool
req, ok = c.requested[up.Label()]
if !ok {
req = c.requested[""]
}
}
requested, limitSid = requestedTracks(c, req, tracks)
}
if replace != "" {
err := delDownConn(c, replace)
if err != nil {
log.Printf("Replace: %v", err)
}
}
// closes over replace, which will be modified below
defer func() {
if replace != "" {
closeDownConn(c, replace, "")
}
}()
if len(requested) == 0 {
closeDownConn(c, id, "")
return nil
}
down, _, err := addDownConn(c, up)
if err != nil {
if errors.Is(err, os.ErrClosed) {
return nil
}
return err
}
done, err := replaceTracks(down, requested, limitSid)
if err != nil || !done {
return err
}
err = negotiate(c, down, false, replace)
if err != nil {
log.Printf("Negotiation failed: %v", err)
closeDownConn(c, down.id, err.Error())
return err
}
replace = ""
return nil
}
func handleAction(c *webClient, a any) error {
switch a := a.(type) {
case pushConnAction:
if c.group == nil || c.group != a.group {
log.Printf("Got connectsions for wrong group")
return nil
}
return pushDownConn(c, a.id, a.conn, a.tracks, a.replace)
case requestConnsAction:
g := c.group
if g == nil || a.group != g {
log.Printf("Misdirected pushConns")
return nil
}
for _, u := range c.up {
if a.id != "" && a.id != u.id {
continue
}
tracks := u.getTracks()
replace := u.getReplace(false)
ts := make([]conn.UpTrack, len(tracks))
for i, t := range tracks {
ts[i] = t
}
err := a.target.PushConn(g, u.id, u, ts, replace)
if err != nil {
log.Printf("PushConn: %v", err)
}
}
case connectionFailedAction:
if down := getDownConn(c, a.id); down != nil {
err := negotiate(c, down, true, "")
if err != nil {
return err
}
tracks := make(
[]conn.UpTrack, len(down.tracks),
)
for i, t := range down.tracks {
tracks[i] = t.remote
}
c.PushConn(
c.group,
down.remote.Id(), down.remote,
tracks, "",
)
} else if up := getUpConn(c, a.id); up != nil {
c.write(clientMessage{
Type: "renegotiate",
Id: a.id,
})
} else {
log.Printf("Attempting to renegotiate " +
"unknown connection")
}
case pushClientAction:
if a.group != c.group.Name() {
log.Printf("got client for wrong group")
return nil
}
perms := append([]string(nil), a.permissions...)
username := a.username
return c.write(clientMessage{
Type: "user",
Kind: a.kind,
Id: a.id,
Username: &username,
Permissions: perms,
Data: a.data,
})
case joinedAction:
var status *group.Status
var data map[string]interface{}
var g *group.Group
if a.group != "" {
g = group.Get(a.group)
if g != nil {
s := g.Status(true, nil)
status = &s
data = g.Data()
}
}
perms := append([]string(nil), c.permissions...)
username := c.username
err := c.write(clientMessage{
Type: "joined",
Kind: a.kind,
Group: a.group,
Username: &username,
Permissions: perms,
Status: status,
Data: data,
RTCConfiguration: ice.ICEConfiguration(),
})
if err != nil {
return err
}
if a.kind == "join" {
if g == nil {
log.Println("g is null when joining" +
"this shouldn't happen")
return nil
}
h := g.GetChatHistory()
for _, m := range h {
err := c.write(clientMessage{
Type: "chathistory",
Id: m.Id,
Source: m.Source,
Username: m.User,
Time: m.Time.Format(time.RFC3339),
Value: m.Value,
Kind: m.Kind,
})
if err != nil {
return err
}
}
}
case permissionsChangedAction:
g := c.Group()
if g == nil {
return errors.New("Permissions changed in no group")
}
perms := append([]string(nil), c.permissions...)
status := g.Status(true, nil)
username := c.username
c.write(clientMessage{
Type: "joined",
Kind: "change",
Group: g.Name(),
Username: &username,
Permissions: perms,
Status: &status,
RTCConfiguration: ice.ICEConfiguration(),
})
if !member("present", c.permissions) {
up := getUpConns(c)
for _, u := range up {
err := delUpConn(
c, u.id, c.id, true,
)
if err == nil {
failUpConnection(
c, u.id,
"permission denied",
)
}
}
}
id := c.Id()
user := c.Username()
d := c.Data()
clients := g.GetClients(nil)
go func(clients []group.Client) {
for _, cc := range clients {
cc.PushClient(
g.Name(), "change", id, user, perms, d,
)
}
}(clients)
case kickAction:
return group.KickError{
a.id, a.username, a.message,
}
default:
log.Printf("unexpected action %T", a)
return errors.New("unexpected action")
}
return nil
}
func failUpConnection(c *webClient, id string, message string) error {
if id != "" {
err := c.write(clientMessage{
Type: "abort",
Id: id,
})
if err != nil {
return err
}
}
if message != "" {
err := c.error(group.UserError(message))
if err != nil {
return err
}
}
return nil
}
func leaveGroup(c *webClient) {
if c.group == nil {
return
}
if c.up != nil {
for id := range c.up {
delUpConn(c, id, c.id, true)
}
}
if c.down != nil {
for id := range c.down {
delDownConn(c, id)
}
}
group.DelClient(c)
c.permissions = nil
c.data = nil
c.requested = make(map[string][]string)
c.group = nil
}
func closeDownConn(c *webClient, id string, message string) error {
err := delDownConn(c, id)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Printf("Close down connection: %v", err)
}
err = c.write(clientMessage{
Type: "close",
Id: id,
})
if err != nil {
return err
}
if message != "" {
err := c.error(group.UserError(message))
if err != nil {
return err
}
}
return nil
}
func setPermissions(g *group.Group, id string, perm string) error {
client := g.GetClient(id)
if client == nil {
return group.UserError("no such user")
}
c, ok := client.(*webClient)
if !ok {
return group.UserError("this is not a real user")
}
switch perm {
case "op":
c.permissions = addnew("op", c.permissions)
if g.Description().AllowRecording {
c.permissions = addnew("record", c.permissions)
}
case "unop":
c.permissions = remove("op", c.permissions)
c.permissions = remove("record", c.permissions)
case "present":
c.permissions = addnew("present", c.permissions)
case "unpresent":
c.permissions = remove("present", c.permissions)
case "shutup":
c.permissions = remove("message", c.permissions)
case "unshutup":
c.permissions = addnew("message", c.permissions)
default:
return group.UserError("unknown permission")
}
c.action(permissionsChangedAction{})
return nil
}
func (c *webClient) Kick(id string, user *string, message string) error {
c.action(kickAction{id, user, message})
return nil
}
func (c *webClient) Joined(group, kind string) error {
c.action(joinedAction{group, kind})
return nil
}
func kickClient(g *group.Group, id string, user *string, dest string, message string) error {
client := g.GetClient(dest)
if client == nil {
return group.UserError("no such user")
}
return client.Kick(id, user, message)
}
func handleClientMessage(c *webClient, m clientMessage) error {
if m.Source != "" {
if m.Source != c.Id() {
return group.ProtocolError("spoofed client id")
}
}
if m.Type != "join" {
if m.Username != nil {
if *m.Username != c.Username() {
return group.ProtocolError("spoofed username")
}
}
}
switch m.Type {
case "join":
if m.Kind == "leave" {
if c.group == nil || c.group.Name() != m.Group {
return group.UserError("you are not joined")
}
leaveGroup(c)
return nil
}
if m.Kind != "join" {
return group.ProtocolError("unknown kind")
}
if c.group != nil {
return group.ProtocolError(
"cannot join multiple groups",
)
}
c.data = m.Data
g, err := group.AddClient(m.Group, c,
group.ClientCredentials{
Username: m.Username,
Password: m.Password,
Token: m.Token,
},
)
if err != nil {
var e, s string
var autherr *group.NotAuthorisedError
if errors.Is(err, token.ErrUsernameRequired) {
s = err.Error()
e = "need-username"
} else if errors.Is(err, group.ErrDuplicateUsername) {
s = err.Error()
e = "duplicate-username"
} else if errors.As(err, &autherr) {
s = "not authorised"
time.Sleep(200 * time.Millisecond)
log.Printf("Join group: %v", err)
} else if errors.Is(err, os.ErrNotExist) {
s = "group does not exist"
} else if _, ok := err.(group.UserError); ok {
s = err.Error()
} else {
s = "internal server error"
log.Printf("Join group: %v", err)
}
username := c.username
return c.write(clientMessage{
Type: "joined",
Kind: "fail",
Error: e,
Group: m.Group,
Username: &username,
Value: s,
})
}
if redirect := g.Description().Redirect; redirect != "" {
// We normally redirect at the HTTP level, but the group
// description could have been edited in the meantime.
username := c.username
return c.write(clientMessage{
Type: "joined",
Kind: "redirect",
Group: m.Group,
Username: &username,
Value: redirect,
})
}
c.group = g
case "request":
requested, err := parseRequested(m.Request)
if err != nil {
return err
}
return c.setRequested(requested)
case "requestStream":
down := getDownConn(c, m.Id)
if down == nil {
return ErrUnknownId
}
requested, err := toStringArray(m.Request)
if err != nil {
return err
}
c.setRequestedStream(down, requested)
case "offer":
if m.Id == "" {
return errEmptyId
}
if !member("present", c.permissions) {
if m.Replace != "" {
delUpConn(c, m.Replace, c.id, true)
}
c.write(clientMessage{
Type: "abort",
Id: m.Id,
})
return c.error(group.UserError("not authorised"))
}
err := gotOffer(c, m.Id, m.Label, m.SDP, m.Replace)
if err != nil {
log.Printf("gotOffer: %v", err)
return failUpConnection(c, m.Id, err.Error())
}
case "answer":
if m.Id == "" {
return errEmptyId
}
err := gotAnswer(c, m.Id, m.SDP)
if err != nil {
log.Printf("gotAnswer: %v", err)
message := ""
if err != ErrUnknownId {
message = err.Error()
}
return closeDownConn(c, m.Id, message)
}
down := getDownConn(c, m.Id)
if down == nil {
return ErrUnknownId
}
if down.negotiationNeeded > negotiationUnneeded {
err := negotiate(
c, down,
down.negotiationNeeded == negotiationRestartIce,
"",
)
if err != nil {
return closeDownConn(c, m.Id, err.Error())
}
}
case "renegotiate":
if m.Id == "" {
return errEmptyId
}
down := getDownConn(c, m.Id)
if down != nil {
err := negotiate(c, down, true, "")
if err != nil {
return closeDownConn(c, m.Id, err.Error())
}
} else {
log.Printf("Trying to renegotiate unknown connection")
}
case "close":
if m.Id == "" {
return errEmptyId
}
err := delUpConn(c, m.Id, c.id, true)
if err != nil {
log.Printf("Deleting up connection %v: %v",
m.Id, err)
return nil
}
case "abort":
if m.Id == "" {
return errEmptyId
}
return closeDownConn(c, m.Id, "")
case "ice":
if m.Id == "" {
return errEmptyId
}
if m.Candidate == nil {
return group.ProtocolError("null candidate")
}
err := gotICE(c, m.Candidate, m.Id)
if err != nil {
log.Printf("ICE: %v", err)
}
case "chat", "usermessage":
g := c.group
if g == nil {
return c.error(group.UserError("join a group first"))
}
required := "message"
if m.Type == "chat" && m.Kind == "caption" {
required = "caption"
}
if !member(required, c.permissions) {
return c.error(group.UserError("not authorised"))
}
id := m.Id
if m.Type == "chat" && m.Dest == "" && id == "" {
buf := make([]byte, 8)
crand.Read(buf)
id = base64.RawURLEncoding.EncodeToString(buf)
}
now := time.Now()
if m.Type == "chat" {
if m.Dest == "" {
g.AddToChatHistory(
id, m.Source, m.Username,
now, m.Kind, m.Value,
)
}
}
mm := clientMessage{
Type: m.Type,
Id: id,
Source: m.Source,
Dest: m.Dest,
Username: m.Username,
Privileged: member("op", c.permissions),
Time: now.Format(time.RFC3339),
Kind: m.Kind,
NoEcho: m.NoEcho,
Value: m.Value,
}
if m.Dest == "" {
var except group.Client
if m.NoEcho {
except = c
}
err := broadcast(g.GetClients(except), mm)
if err != nil {
log.Printf("broadcast(chat): %v", err)
}
} else {
cc := g.GetClient(m.Dest)
if cc == nil {
return c.error(group.UserError("user unknown"))
}
ccc, ok := cc.(*webClient)
if !ok {
return c.error(group.UserError(
"this user doesn't chat",
))
}
ccc.write(mm)
}
case "groupaction":
g := c.group
if g == nil {
return c.error(group.UserError("join a group first"))
}
switch m.Kind {
case "clearchat":
if !member("op", c.permissions) {
return c.error(group.UserError("not authorised"))
}
var id, userId string
if m.Value != nil {
value, ok := m.Value.(map[string]any)
if !ok {
return c.error(group.UserError(
"bad value in clearchat",
))
}
id, _ = value["id"].(string)
userId, _ = value["userId"].(string)
if userId == "" && id != "" {
return c.error(group.UserError(
"bad value in clearchat",
))
}
}
g.ClearChatHistory(id, userId)
m := clientMessage{
Type: "usermessage",
Kind: "clearchat",
Value: m.Value,
Privileged: true,
}
err := broadcast(g.GetClients(nil), m)
if err != nil {
log.Printf("broadcast(clearchat): %v", err)
}
case "lock", "unlock":
if !member("op", c.permissions) {
return c.error(group.UserError("not authorised"))
}
message := ""
v, ok := m.Value.(string)
if ok {
message = v
}
g.SetLocked(m.Kind == "lock", message)
case "record":
if !member("record", c.permissions) {
return c.error(group.UserError("not authorised"))
}
for _, cc := range g.GetClients(c) {
_, ok := cc.(*diskwriter.Client)
if ok {
return c.error(group.UserError("already recording"))
}
}
disk := diskwriter.New(g)
_, err := group.AddClient(g.Name(), disk,
group.ClientCredentials{
System: true,
},
)
if err != nil {
disk.Close()
return c.error(err)
}
requestConns(disk, c.group, "")
case "unrecord":
if !member("record", c.permissions) {
return c.error(group.UserError("not authorised"))
}
for _, cc := range g.GetClients(c) {
disk, ok := cc.(*diskwriter.Client)
if ok {
disk.Close()
group.DelClient(disk)
}
}
case "subgroups":
if !member("op", c.permissions) {
return c.error(group.UserError("not authorised"))
}
s := ""
for _, sg := range group.GetSubGroups(g.Name()) {
plural := ""
if sg.Clients > 1 {
plural = "s"
}
s = s + fmt.Sprintf("%v (%v client%v)\n",
sg.Name, sg.Clients, plural)
}
username := "Server"
c.write(clientMessage{
Type: "chat",
Dest: c.id,
Username: &username,
Time: time.Now().Format(time.RFC3339),
Value: s,
})
case "setdata":
if !member("op", c.permissions) {
return c.error(group.UserError("not authorised"))
}
data, ok := m.Value.(map[string]interface{})
if !ok {
return c.error(group.UserError(
"Bad value in setdata",
))
}
g.UpdateData(data)
case "maketoken":
terror := func(e, m string) error {
return c.write(clientMessage{
Type: "usermessage",
Kind: "token",
Privileged: true,
Error: e,
Value: m,
})
}
if !member("token", c.permissions) {
return terror("not-authorised", "not authorised")
}
tok, err := parseStatefulToken(m.Value)
if err != nil {
return terror("error", err.Error())
}
if tok.Token == "" {
buf := make([]byte, 8)
crand.Read(buf)
tok.Token =
base64.RawURLEncoding.EncodeToString(buf)
} else {
return terror("error", "client specified token")
}
if tok.Group != c.group.Name() {
return terror("error", "wrong group in token")
}
if tok.Expires == nil {
return terror("error", "token doesn't expire")
}
if tok.Username != nil &&
c.group.UserExists(*tok.Username) {
return terror("error", "that username is taken")
}
for _, p := range tok.Permissions {
if !member(p, c.permissions) {
return terror(
"not-authorised",
"not authorised",
)
}
}
user := c.username
if user != "" {
tok.IssuedBy = &user
}
now := time.Now().UTC()
tok.IssuedAt = &now
new, err := token.Update(tok, "")
if err != nil {
return terror("error", err.Error())
}
c.write(clientMessage{
Type: "usermessage",
Kind: "token",
Privileged: true,
Value: new,
})
case "edittoken":
terror := func(e, m string) error {
return c.write(clientMessage{
Type: "usermessage",
Kind: "token",
Privileged: true,
Error: e,
Value: m,
})
}
if !member("op", c.permissions) ||
!member("token", c.permissions) {
return terror("not-authorised", "not authorised")
}
tok, err := parseStatefulToken(m.Value)
if err != nil {
return terror("error", err.Error())
}
if tok.Group != "" || tok.Username != nil ||
tok.Permissions != nil ||
tok.IssuedBy != nil ||
tok.IssuedAt != nil {
return terror(
"error", "this field cannot be edited",
)
}
old, etag, err := token.Get(tok.Token)
if err != nil {
return terror("error", err.Error())
}
t := old.Clone()
if tok.Expires != nil {
t.Expires = tok.Expires
}
if tok.NotBefore != nil {
t.NotBefore = tok.NotBefore
}
new, err := token.Update(t, etag)
if err != nil {
return terror("error", err.Error())
}
c.write(clientMessage{
Type: "usermessage",
Kind: "token",
Privileged: true,
Value: new,
})
case "listtokens":
terror := func(e, m string) error {
return c.write(clientMessage{
Type: "usermessage",
Kind: "tokenlist",
Privileged: true,
Error: e,
Value: m,
})
}
if !member("op", c.permissions) ||
!member("token", c.permissions) {
return terror("not-authorised", "not authorised")
}
tokens, _, err := token.List(c.group.Name())
if err != nil {
return terror("error", err.Error())
}
c.write(clientMessage{
Type: "usermessage",
Kind: "tokenlist",
Privileged: true,
Value: tokens,
})
default:
return group.UserError("unknown group action")
}
case "useraction":
g := c.group
if g == nil {
return c.error(group.UserError("join a group first"))
}
switch m.Kind {
case "op", "unop", "present", "unpresent", "shutup", "unshutup":
if !member("op", c.permissions) {
return c.error(group.UserError("not authorised"))
}
err := setPermissions(g, m.Dest, m.Kind)
if err != nil {
return c.error(err)
}
case "identify":
if !member("op", c.permissions) {
return c.error(group.UserError("not authorised"))
}
d := g.GetClient(m.Dest)
if d == nil {
return c.error(
group.UserError("client not found"),
)
}
value := make(map[string]any)
value["id"] = d.Id()
if username := d.Username(); username != "" {
value["username"] = username
}
if addr := d.Addr(); addr != nil {
value["address"] = addr.String()
}
type warner interface {
Warn(bool, string) error
}
w, ok := d.(warner)
if ok {
w.Warn(false, "Your IP address has been "+
"communicated to user "+
c.Username()+".")
}
c.write(clientMessage{
Type: "usermessage",
Kind: "userinfo",
Privileged: true,
Value: value,
})
case "kick":
if !member("op", c.permissions) {
return c.error(group.UserError("not authorised"))
}
message := ""
v, ok := m.Value.(string)
if ok {
message = v
}
err := kickClient(g, m.Source, m.Username, m.Dest, message)
if err != nil {
return c.error(err)
}
case "setdata":
if m.Dest != c.Id() {
return c.error(group.UserError("not authorised"))
}
data, ok := m.Value.(map[string]interface{})
if !ok {
return c.error(group.UserError(
"Bad value in setdata",
))
}
if c.data == nil {
c.data = make(map[string]interface{})
}
for k, v := range data {
if v == nil {
delete(c.data, k)
} else {
c.data[k] = v
}
}
id := c.Id()
user := c.Username()
perms := c.Permissions()
data = c.Data()
go func(clients []group.Client) {
for _, cc := range clients {
cc.PushClient(
g.Name(), "change",
id, user, perms, data,
)
}
}(g.GetClients(nil))
default:
return group.UserError("unknown user action")
}
case "pong":
// nothing
case "ping":
return c.write(clientMessage{
Type: "pong",
})
default:
log.Printf("unexpected message: %v", m.Type)
return group.ProtocolError("unexpected message")
}
return nil
}
func parseStatefulToken(value interface{}) (*token.Stateful, error) {
data, ok := value.(map[string]interface{})
if !ok || data == nil {
return nil, errors.New("bad token value")
}
parseString := func(key string) (*string, error) {
v := data[key]
if v == nil {
return nil, nil
}
vv, ok := v.(string)
if !ok {
return nil, errors.New("bad string value")
}
return &vv, nil
}
parseStringList := func(key string) ([]string, error) {
v := data[key]
if v == nil {
return nil, nil
}
vv, ok := v.([]interface{})
if !ok {
return nil, errors.New("bad string list")
}
vvv := make([]string, 0, len(vv))
for _, s := range vv {
ss, ok := s.(string)
if !ok {
return nil, errors.New("bad string list")
}
vvv = append(vvv, ss)
}
return vvv, nil
}
parseTime := func(key string) (*time.Time, error) {
v := data[key]
if v == nil {
return nil, nil
}
switch v := v.(type) {
case string:
vv, err := time.Parse(time.RFC3339, v)
if err != nil {
return nil, errors.New("bad time value")
}
return &vv, nil
case float64: // relative time
vv := time.Now().Add(time.Duration(v) * time.Millisecond)
return &vv, nil
default:
return nil, errors.New("bad time value")
}
}
t, err := parseString("token")
if err != nil {
return nil, err
}
tt := ""
if t != nil {
tt = *t
}
u, err := parseString("username")
if err != nil {
return nil, err
}
g, err := parseString("group")
if err != nil {
return nil, err
}
gg := ""
if g != nil {
gg = *g
}
p, err := parseStringList("permissions")
if err != nil {
return nil, err
}
e, err := parseTime("expires")
if err != nil {
return nil, err
}
n, err := parseTime("not-before")
if err != nil {
return nil, err
}
return &token.Stateful{
Token: tt,
Group: gg,
Username: u,
Permissions: p,
Expires: e,
NotBefore: n,
}, nil
}
func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) {
defer close(read)
for {
var m clientMessage
err := conn.ReadJSON(&m)
if err != nil {
select {
case read <- err:
return
case <-done:
return
}
}
select {
case read <- m:
case <-done:
return
}
}
}
func clientWriter(conn *websocket.Conn, ch <-chan interface{}, done chan<- struct{}) {
defer func() {
close(done)
conn.Close()
}()
for {
m, ok := <-ch
if !ok {
break
}
err := conn.SetWriteDeadline(
time.Now().Add(500 * time.Millisecond),
)
if err != nil {
return
}
switch m := m.(type) {
case clientMessage:
err := conn.WriteJSON(m)
if err != nil {
return
}
case []byte:
err := conn.WriteMessage(websocket.TextMessage, m)
if err != nil {
return
}
case closeMessage:
if m.data != nil {
conn.WriteMessage(
websocket.CloseMessage,
m.data,
)
}
return
default:
log.Printf("clientWriter: unexpected message %T", m)
return
}
}
}
func (c *webClient) Warn(oponly bool, message string) error {
if oponly && !member("op", c.permissions) {
return nil
}
return c.write(clientMessage{
Type: "usermessage",
Kind: "warning",
Dest: c.id,
Privileged: true,
Value: message,
})
}
var ErrClientDead = errors.New("client is dead")
func (c *webClient) action(a interface{}) {
c.actions.Put(a)
}
func (c *webClient) write(m clientMessage) error {
select {
case c.writeCh <- m:
return nil
case <-c.writerDone:
return ErrClientDead
}
}
func broadcast(cs []group.Client, m clientMessage) error {
b, err := json.Marshal(m)
if err != nil {
return err
}
for _, c := range cs {
cc, ok := c.(*webClient)
if !ok {
continue
}
select {
case cc.writeCh <- b:
case <-cc.writerDone:
}
}
return nil
}
func (c *webClient) close(data []byte) error {
select {
case c.writeCh <- closeMessage{data}:
return nil
case <-c.writerDone:
return ErrClientDead
}
}
func errorMessage(id string, err error) *clientMessage {
switch e := err.(type) {
case group.UserError:
return &clientMessage{
Type: "usermessage",
Kind: "error",
Dest: id,
Privileged: true,
Value: e.Error(),
}
case group.KickError:
message := e.Message
if message == "" {
message = "you have been kicked out"
}
return &clientMessage{
Type: "usermessage",
Kind: "kicked",
Id: e.Id,
Username: e.Username,
Dest: id,
Privileged: true,
Value: message,
}
default:
return nil
}
}
func (c *webClient) error(err error) error {
m := errorMessage(c.id, err)
if m == nil {
return err
}
return c.write(*m)
}