1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-10 02:35:58 +01:00
galene/webclient.go
2020-06-01 00:02:17 +02:00

1738 lines
34 KiB
Go

// Copyright (c) 2020 by Juliusz Chroboczek.
// This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist.
package main
import (
"encoding/json"
"errors"
"io"
"log"
"math"
"math/bits"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"sfu/estimator"
"sfu/jitter"
"sfu/mono"
"sfu/packetcache"
"github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v2"
)
// iceConf caches the ICE configuration; iceOnce makes sure it is loaded
// from iceFilename at most once.
var iceConf webrtc.Configuration
var iceOnce sync.Once

// iceConfiguration returns the WebRTC configuration built from the JSON
// list of ICE servers stored in the file named by iceFilename.  The file
// is read on the first call only.  If it cannot be opened or decoded, the
// error is logged and the empty configuration is returned from then on.
func iceConfiguration() webrtc.Configuration {
	iceOnce.Do(func() {
		f, err := os.Open(iceFilename)
		if err != nil {
			log.Printf("Open %v: %v", iceFilename, err)
			return
		}
		defer f.Close()

		var servers []webrtc.ICEServer
		if err := json.NewDecoder(f).Decode(&servers); err != nil {
			log.Printf("Get ICE configuration: %v", err)
			return
		}
		iceConf = webrtc.Configuration{ICEServers: servers}
	})
	return iceConf
}
// protocolError reports that the peer violated the client/server
// protocol; the string is the human-readable reason.
type protocolError string

// Error returns the reason carried by the error.
func (e protocolError) Error() string {
	msg := string(e)
	return msg
}
// userError is an error whose message is meant to be shown to the user.
type userError string

// Error returns the user-facing message.
func (e userError) Error() string {
	msg := string(e)
	return msg
}
// errorToWSCloseMessage converts the error that ended a client session
// into (text, frame).  text is a message suitable for the user; it is
// empty unless err is a protocolError or a userError.  frame is the
// payload of a WebSocket close frame.  The close code is:
//   - normal closure for a *websocket.CloseError or a userError;
//   - protocol error for a protocolError;
//   - internal server error for anything else, including nil.
func errorToWSCloseMessage(err error) (string, []byte) {
	code, text := websocket.CloseInternalServerErr, ""
	if _, ok := err.(*websocket.CloseError); ok {
		code = websocket.CloseNormalClosure
	} else if pe, ok := err.(protocolError); ok {
		code, text = websocket.CloseProtocolError, string(pe)
	} else if ue, ok := err.(userError); ok {
		code, text = websocket.CloseNormalClosure, string(ue)
	}
	return text, websocket.FormatCloseMessage(code, text)
}
// isWSNormalError reports whether err is a WebSocket close error whose
// code is either a normal closure or "going away".
func isWSNormalError(err error) bool {
	benign := []int{websocket.CloseNormalClosure, websocket.CloseGoingAway}
	return websocket.IsCloseError(err, benign...)
}
// webClient is the server-side state of a client connected over a
// WebSocket (see startClient).
type webClient struct {
// group is the group joined by the client; set by startClient once
// addClient has succeeded.
group *group
id string
username string
// permissions are sent to the client in "permissions" messages and
// checked before privileged requests.
permissions userPermission
// requested maps a track label to the rate requested for it; a non-zero
// value means the client wants tracks with that label (see isRequested).
requested map[string]uint32
// done is closed when startClient returns.
done chan struct{}
// writeCh carries clientMessage and closeMessage values to the
// clientWriter goroutine; writerDone is closed when that goroutine exits.
writeCh chan interface{}
writerDone chan struct{}
// actionCh carries actions from other goroutines to clientLoop.
actionCh chan interface{}
// mu is taken by helpers such as getUpConn, delUpConn and delDownConn
// when they access up and down.
mu sync.Mutex
// down holds the downstream (server-to-client) connections, by id.
down map[string]*rtpDownConnection
// up holds the upstream (client-to-server) connections, by id.
up map[string]*upConnection
}
// getGroup returns the group the client belongs to (nil until startClient
// has added it to a group).
func (c *webClient) getGroup() *group {
	g := c.group
	return g
}

// getId returns the id supplied by the client in its handshake.
func (c *webClient) getId() string {
	id := c.id
	return id
}

// getUsername returns the username supplied by the client in its
// handshake.
func (c *webClient) getUsername() string {
	name := c.username
	return name
}
// pushClient sends this client a "user" message about the user with the
// given id and username.  The message carries Del set when add is false,
// i.e. when the user is being removed.
func (c *webClient) pushClient(id, username string, add bool) error {
	msg := clientMessage{Type: "user", Id: id, Username: username}
	msg.Del = !add
	return c.write(msg)
}
// rateMap maps track labels to requested rates.  In JSON, a rate of 0 is
// written as false, the maximum uint32 value as true, and any other rate
// as a number.
type rateMap map[string]uint32

// UnmarshalJSON decodes a JSON object whose values are booleans (false is
// 0, true is the maximum uint32) or numbers in [0, 2^32-1).  A negative
// or too large number yields an "overflow" error.  Any other value type
// is an error.  On error the receiver is left unchanged.
func (v *rateMap) UnmarshalJSON(b []byte) error {
	var raw map[string]interface{}
	if err := json.Unmarshal(b, &raw); err != nil {
		return err
	}

	out := make(map[string]uint32, len(raw))
	for key, val := range raw {
		switch x := val.(type) {
		case bool:
			var rate uint32
			if x {
				rate = ^uint32(0)
			}
			out[key] = rate
		case float64:
			if x < 0 || x >= float64(^uint32(0)) {
				return errors.New("overflow")
			}
			out[key] = uint32(x)
		default:
			return errors.New("unexpected type in JSON map")
		}
	}

	*v = out
	return nil
}

// MarshalJSON encodes the map as a JSON object, using false for 0, true
// for the maximum uint32 value and a number otherwise.
func (v rateMap) MarshalJSON() ([]byte, error) {
	out := make(map[string]interface{}, len(v))
	for key, rate := range v {
		var enc interface{} = rate
		if rate == 0 {
			enc = false
		} else if rate == ^uint32(0) {
			enc = true
		}
		out[key] = enc
	}
	return json.Marshal(out)
}
// clientMessage is a JSON message exchanged with the client over the
// WebSocket.  Type selects the kind of message (e.g. "handshake",
// "offer", "answer", "ice", "chat", "user", "error") and determines which
// of the other fields are meaningful.
type clientMessage struct {
Type string `json:"type"`
Id string `json:"id,omitempty"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Permissions userPermission `json:"permissions,omitempty"`
Group string `json:"group,omitempty"`
Value string `json:"value,omitempty"`
Me bool `json:"me,omitempty"`
Offer *webrtc.SessionDescription `json:"offer,omitempty"`
Answer *webrtc.SessionDescription `json:"answer,omitempty"`
Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"`
// Labels maps transceiver mids to track labels.
Labels map[string]string `json:"labels,omitempty"`
// Del is set in "user" messages that announce a removal.
Del bool `json:"del,omitempty"`
// Request maps track labels to requested rates (see rateMap).
Request rateMap `json:"request,omitempty"`
}
// closeMessage asks clientWriter to send a WebSocket close frame whose
// payload is data (as built by websocket.FormatCloseMessage).
type closeMessage struct {
data []byte
}
// startClient runs a client session on conn until it ends.
//
// The first message must arrive within 15 seconds and must be a
// "handshake" carrying the client's id, username, password and group.
// The writer goroutine is then started, the client joins the group, and
// clientLoop runs.  Unless the session ends with a normal WebSocket
// closure, the client is then sent an error message (for user and
// protocol errors) followed by a close frame.  Closing writeCh makes the
// writer exit and close conn.
//
// It returns the error that ended the session, or nil for a normal
// closure.
func startClient(conn *websocket.Conn) (err error) {
	var m clientMessage

	err = conn.SetReadDeadline(time.Now().Add(15 * time.Second))
	if err != nil {
		conn.Close()
		return
	}
	err = conn.ReadJSON(&m)
	if err != nil {
		conn.Close()
		return
	}
	err = conn.SetReadDeadline(time.Time{})
	if err != nil {
		conn.Close()
		return
	}

	if m.Type != "handshake" {
		// report the malformed handshake instead of ending silently
		conn.Close()
		err = protocolError("expected handshake")
		return
	}

	c := &webClient{
		id:       m.Id,
		username: m.Username,
		actionCh: make(chan interface{}, 10),
		done:     make(chan struct{}),
	}

	defer close(c.done)

	c.writeCh = make(chan interface{}, 25)
	c.writerDone = make(chan struct{})
	defer func() {
		if isWSNormalError(err) {
			err = nil
		} else {
			m, e := errorToWSCloseMessage(err)
			if m != "" {
				c.write(clientMessage{
					Type:  "error",
					Value: m,
				})
			}
			select {
			case c.writeCh <- closeMessage{e}:
			case <-c.writerDone:
			}
		}
		close(c.writeCh)
		c.writeCh = nil
	}()

	go clientWriter(conn, c.writeCh, c.writerDone)

	// Validated only once the writer is running, so that the error is
	// delivered to the client and the deferred shutdown closes conn.
	// Returning any earlier would leak the connection.
	if strings.ContainsRune(m.Username, ' ') {
		err = userError("don't put spaces in your username")
		return
	}

	g, err := addClient(m.Group, c, m.Username, m.Password)
	if err != nil {
		return
	}
	c.group = g
	defer delClient(c)

	return clientLoop(c, conn)
}
// getUpConn returns the client's upstream connection with the given id,
// or nil if there is none.
func getUpConn(c *webClient, id string) *upConnection {
	c.mu.Lock()
	defer c.mu.Unlock()
	// indexing a nil map yields nil, so no separate nil check is needed
	return c.up[id]
}
// getUpConns returns a snapshot, in unspecified order, of the ids of the
// client's upstream connections.
func getUpConns(c *webClient) []string {
	c.mu.Lock()
	ids := make([]string, 0, len(c.up))
	for id := range c.up {
		ids = append(ids, id)
	}
	c.mu.Unlock()
	return ids
}
// addUpConn creates an upstream (client-to-server) peer connection with
// the given id, with one receive-only audio and one receive-only video
// transceiver, and records it in c.up.  It fails if the client already
// has an upstream or downstream connection with that id.  The new peer
// connection is closed on every failure path.
//
// It also installs the connection's callbacks:
//   - local ICE candidates are forwarded to the client;
//   - an rtcpUpSender goroutine sends periodic receiver reports;
//   - the OnTrack handler wraps each incoming track in an upTrack and
//     starts its readLoop and rtcpUpListener goroutines.  Once
//     u.complete() reports true, it pushes the connection to the other
//     clients of the group.
func addUpConn(c *webClient, id string) (*upConnection, error) {
pc, err := groups.api.NewPeerConnection(iceConfiguration())
if err != nil {
return nil, err
}
_, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio,
webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
},
)
if err != nil {
pc.Close()
return nil, err
}
_, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo,
webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
},
)
if err != nil {
pc.Close()
return nil, err
}
conn := &upConnection{id: id, pc: pc}
c.mu.Lock()
defer c.mu.Unlock()
if c.up == nil {
c.up = make(map[string]*upConnection)
}
// ids are shared between upstream and downstream connections
if c.up[id] != nil || (c.down != nil && c.down[id] != nil) {
conn.pc.Close()
return nil, errors.New("Adding duplicate connection")
}
c.up[id] = conn
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, id, candidate)
})
go rtcpUpSender(c, conn)
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
c.mu.Lock()
// the connection may have been deleted before the track arrived
u, ok := c.up[id]
if !ok {
log.Printf("Unknown connection")
c.mu.Unlock()
return
}
mid := getUpMid(pc, remote)
if mid == "" {
log.Printf("Couldn't get track's mid")
c.mu.Unlock()
return
}
// labels are keyed by mid (set from the offer's Labels in gotOffer);
// fall back to the track's kind if the client gave no label
label, ok := u.labels[mid]
if !ok {
log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
if isvideo {
label = "video"
} else {
label = "audio"
}
}
track := &upTrack{
track: remote,
label: label,
cache: packetcache.New(96),
rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate),
maxBitrate: ^uint64(0),
localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}),
}
u.tracks = append(u.tracks, track)
// snapshot the track list under c.mu so that it can be pushed to the
// other clients after the lock is released
var tracks []*upTrack
if u.complete() {
tracks = make([]*upTrack, len(u.tracks))
copy(tracks, u.tracks)
}
// group-wide video track count, read by sendRateUpdate
if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.group.videoCount, 1)
}
c.mu.Unlock()
go readLoop(conn, track)
go rtcpUpListener(conn, track, receiver)
if tracks != nil {
clients := c.group.getClients(c)
for _, cc := range clients {
cc.pushConn(u, tracks, u.label)
}
}
})
return conn, nil
}
// packetIndex locates a packet in an upTrack's packet cache: its RTP
// sequence number and the index returned by cache.Store.  readLoop sends
// these to writeLoop.
type packetIndex struct {
seqno uint16
index uint16
}
// readLoop reads RTP packets from an upstream track until reading fails
// (io.EOF ends the loop silently; other errors are logged).  It also
// starts writeLoop, which receives the location of each stored packet and
// stops when readLoop returns.
//
// For each packet:
//   - its size is added to the rate estimate;
//   - its timestamp feeds the jitter computation;
//   - it is stored in the packet cache.
//
// When the packet's sequence number is more than 24 ahead of the value
// `first` returned by cache.Store (presumably the start of the cache's
// loss window), a NACK is sent for the packets the cache reports missing.
// If writeLoop is congested, a video packet that does not end a frame
// causes up to 7 following packets to be discarded, stopping after the
// next packet with the marker bit set.
func readLoop(conn *upConnection, track *upTrack) {
	video := track.track.Kind() == webrtc.RTPCodecTypeVideo

	ch := make(chan packetIndex, 32)
	defer close(ch)
	go writeLoop(conn, track, ch)

	buf := make([]byte, packetcache.BufSize)
	var packet rtp.Packet
	toDrop := 0

	for {
		n, err := track.track.Read(buf)
		if err != nil {
			if err != io.EOF {
				log.Printf("%v", err)
			}
			return
		}

		track.rate.Add(uint32(n))

		if err := packet.Unmarshal(buf[:n]); err != nil {
			log.Printf("%v", err)
			continue
		}
		track.jitter.Accumulate(packet.Timestamp)

		first, index := track.cache.Store(packet.SequenceNumber, buf[:n])
		if packet.SequenceNumber-first > 24 {
			if found, from, bitmap := track.cache.BitmapGet(); found {
				if err := conn.sendNACK(track, from, bitmap); err != nil {
					log.Printf("%v", err)
				}
			}
		}

		if toDrop > 0 {
			// discarding the rest of a frame; the marker bit ends it
			if packet.Marker {
				toDrop = 0
			} else {
				toDrop--
			}
			continue
		}

		select {
		case ch <- packetIndex{packet.SequenceNumber, index}:
		default:
			// the writer is congested: drop until the end of the frame
			if video && !packet.Marker {
				toDrop = 7
			}
		}
	}
}
// writeLoop forwards the packets that readLoop stored in track's cache to
// every local (downstream) track currently attached to track.  Local
// tracks are added and removed by localTrackAction messages received on
// track.localCh.
//
// When a local track reports ErrKeyframeNeeded, a FIR is requested
// upstream, falling back to a PLI if FIR feedback is not supported.  The
// FIR sequence number is incremented only for the first request since a
// local track was added (or the first request ever); later requests
// repeat it.  Other write errors, except io.ErrClosedPipe, are logged.
// Bytes successfully written are reported to each local track with
// Accumulate.
//
// writeLoop returns when ch is closed and closes track.writerDone on exit.
func writeLoop(conn *upConnection, track *upTrack, ch <-chan packetIndex) {
defer close(track.writerDone)
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
local := make([]downTrack, 0)
// whether a FIR has been sent since the last local track was added
firSent := false
for {
select {
case action := <-track.localCh:
if action.add {
local = append(local, action.track)
firSent = false
} else {
found := false
for i, t := range local {
if t == action.track {
local = append(local[:i], local[i+1:]...)
found = true
break
}
}
if !found {
log.Printf("Deleting unknown track!")
}
}
case pi, ok := <-ch:
if !ok {
return
}
// the packet may already have been evicted from the cache
bytes := track.cache.GetAt(pi.seqno, pi.index, buf)
if bytes == 0 {
continue
}
err := packet.Unmarshal(buf[:bytes])
if err != nil {
log.Printf("%v", err)
continue
}
kfNeeded := false
for _, l := range local {
err := l.WriteRTP(&packet)
if err != nil {
if err == ErrKeyframeNeeded {
kfNeeded = true
} else if err != io.ErrClosedPipe {
log.Printf("WriteRTP: %v", err)
}
continue
}
l.Accumulate(uint32(bytes))
}
if kfNeeded {
err := conn.sendFIR(track, !firSent)
if err == ErrUnsupportedFeedback {
err := conn.sendPLI(track)
if err != nil &&
err != ErrUnsupportedFeedback {
log.Printf("sendPLI: %v", err)
}
} else if err != nil {
log.Printf("sendFIR: %v", err)
}
firSent = true
}
}
}
}
// rtcpUpListener reads RTCP packets arriving on an upstream track's
// receiver until ReadRTCP fails (io.EOF is not logged).  For each sender
// report it records two values, which sendRR later uses to fill in the
// LSR and DLSR fields of receiver reports:
//   - the middle 32 bits of the report's NTP timestamp;
//   - the local arrival time, from mono.Now(0x10000).
// Other packet types are ignored.
func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
	for {
		packets, err := r.ReadRTCP()
		if err != nil {
			if err != io.EOF {
				log.Printf("ReadRTCP: %v", err)
			}
			return
		}
		for _, p := range packets {
			sr, ok := p.(*rtcp.SenderReport)
			if !ok {
				continue
			}
			atomic.StoreUint32(&track.lastSenderReport,
				uint32(sr.NTPTime>>16))
			atomic.StoreUint32(&track.lastSenderReportTime,
				uint32(mono.Now(0x10000)))
		}
	}
}
// sendRR sends a receiver report on conn's peer connection, with one
// reception report per track of conn.  It does nothing if conn has no
// tracks yet.
//
// LSR is the middle 32 bits of the NTP timestamp of the last sender
// report received, and DLSR is the time elapsed since it arrived, in
// mono.Now(0x10000) units (presumably 1/65536 s, as RFC 3550 requires).
// As required by RFC 3550 section 6.4.1, both are zero until a sender
// report has been received.
func sendRR(c *webClient, conn *upConnection) error {
	c.mu.Lock()
	if len(conn.tracks) == 0 {
		c.mu.Unlock()
		return nil
	}

	now := uint32(mono.Now(0x10000))
	reports := make([]rtcp.ReceptionReport, 0, len(conn.tracks))
	for _, t := range conn.tracks {
		expected, lost, totalLost, eseqno := t.cache.GetStats(true)
		if expected == 0 {
			expected = 1
		}
		if lost >= expected {
			lost = expected - 1
		}

		lastSR := atomic.LoadUint32(&t.lastSenderReport)
		var delay uint32
		if lastSR != 0 {
			delay = now - atomic.LoadUint32(&t.lastSenderReportTime)
		}

		// computed in 64 bits so that lost*256 cannot overflow
		fractionLost := uint8(uint64(lost) * 256 / uint64(expected))

		reports = append(reports, rtcp.ReceptionReport{
			SSRC:               t.track.SSRC(),
			FractionLost:       fractionLost,
			TotalLost:          totalLost,
			LastSequenceNumber: eseqno,
			Jitter:             t.jitter.Jitter(),
			LastSenderReport:   lastSR,
			Delay:              delay,
		})
	}
	c.mu.Unlock()

	return conn.pc.WriteRTCP([]rtcp.Packet{
		&rtcp.ReceiverReport{
			SSRC:    1,
			Reports: reports,
		},
	})
}
// rtcpUpSender sends a receiver report for conn about once a second.  It
// stops when sending fails with io.EOF or io.ErrClosedPipe (presumably
// meaning the peer connection is closed).  Other errors are logged, and
// sending continues.
func rtcpUpSender(c *webClient, conn *upConnection) {
	for {
		time.Sleep(time.Second)
		switch err := sendRR(c, conn); err {
		case nil:
		case io.EOF, io.ErrClosedPipe:
			return
		default:
			log.Printf("WriteRTCP: %v", err)
		}
	}
}
// delUpConn closes and removes the client's upstream connection with the
// given id.  The group's video track count is decremented once for each
// of the connection's video tracks.  It returns false if there was no
// such connection.
func delUpConn(c *webClient, id string) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	conn := c.up[id]
	if conn == nil {
		return false
	}

	for _, t := range conn.tracks {
		if t.track.Kind() != webrtc.RTPCodecTypeVideo {
			continue
		}
		// adding ^uint32(0) decrements; wrapping around to the maximum
		// value means the counter was already zero
		if atomic.AddUint32(&c.group.videoCount, ^uint32(0)) == ^uint32(0) {
			log.Printf("Negative track count!")
			atomic.StoreUint32(&c.group.videoCount, 0)
		}
	}

	conn.Close()
	delete(c.up, id)
	return true
}
// getDownConn returns the client's downstream connection with the given
// id, or nil if there is none.
//
// c.down is read only while c.mu is held, including the implicit nil
// check, so that this does not race with code that creates or modifies
// the map.
func getDownConn(c *webClient, id string) *rtpDownConnection {
	c.mu.Lock()
	defer c.mu.Unlock()
	// indexing a nil map yields nil
	return c.down[id]
}
// getConn returns the client's upstream or downstream connection with the
// given id as an iceConnection, looking at upstream connections first.
// It returns nil if the client has neither.
func getConn(c *webClient, id string) iceConnection {
	if up := getUpConn(c, id); up != nil {
		return up
	}
	if down := getDownConn(c, id); down != nil {
		return down
	}
	// an untyped nil, not a typed nil pointer, so that callers' nil
	// checks work
	return nil
}
// addDownConn creates a downstream (server-to-client) peer connection
// with the given id that will forward tracks from remote, registers it
// with remote, and records it in c.down.  It fails if the client already
// has an upstream or downstream connection with that id, or if remote
// refuses it.  The new peer connection is closed on failure.
func addDownConn(c *webClient, id string, remote *upConnection) (*rtpDownConnection, error) {
	pc, err := groups.api.NewPeerConnection(iceConfiguration())
	if err != nil {
		return nil, err
	}

	pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		sendICE(c, id, candidate)
	})

	pc.OnTrack(func(_ *webrtc.Track, _ *webrtc.RTPReceiver) {
		log.Printf("Got track on downstream connection")
	})

	conn := &rtpDownConnection{
		id:     id,
		client: c,
		pc:     pc,
		remote: remote,
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// The map is created under c.mu so that this does not race with
	// readers that hold the lock, such as delDownConn.
	if c.down == nil {
		c.down = make(map[string]*rtpDownConnection)
	}

	if c.down[id] != nil || c.up[id] != nil {
		conn.pc.Close()
		return nil, errors.New("Adding duplicate connection")
	}

	err = remote.addLocal(conn)
	if err != nil {
		conn.pc.Close()
		return nil, err
	}

	c.down[id] = conn
	return conn, nil
}
// delDownConn closes and removes the client's downstream connection with
// the given id, unregistering the connection and each of its tracks from
// the upstream side.  It returns false if there was no such connection.
func delDownConn(c *webClient, id string) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	conn := c.down[id]
	if conn == nil {
		return false
	}

	conn.remote.delLocal(conn)
	for _, t := range conn.tracks {
		// Tracks are registered with their remote only once the answer
		// has been received (see gotAnswer), so one may be missing
		// here; any error is ignored.
		t.remote.delLocal(t)
	}

	conn.pc.Close()
	delete(c.down, id)
	return true
}
// addDownTrack creates a local track on conn's peer connection that
// mirrors remoteTrack (same payload type, SSRC, id and label), adds it to
// conn, and starts rtcpDownListener for it.  It returns the RTP sender of
// the new track.  The track is not registered with remoteTrack here.  The
// c and remoteConn parameters are unused.
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack *upTrack, remoteConn *upConnection) (*webrtc.RTPSender, error) {
	src := remoteTrack.track
	local, err := conn.pc.NewTrack(
		src.PayloadType(), src.SSRC(), src.ID(), src.Label(),
	)
	if err != nil {
		return nil, err
	}

	sender, err := conn.pc.AddTrack(local)
	if err != nil {
		return nil, err
	}

	down := &rtpDownTrack{
		track:          local,
		remote:         remoteTrack,
		maxLossBitrate: new(bitrate),
		maxREMBBitrate: new(bitrate),
		stats:          new(receiverStats),
		rate:           estimator.New(time.Second),
	}
	conn.tracks = append(conn.tracks, down)

	go rtcpDownListener(conn, down, sender)
	return sender, nil
}
// Bounds and starting point of the loss-based rate estimate maintained by
// updateRate.  updateRate compares the estimate with 8 times the byte-rate
// estimate, so these are presumably in bits per second.
const (
	minLossRate  = 9600
	initLossRate = 512000
	maxLossRate  = 1 << 30
)
// updateRate updates the track's loss-based bitrate estimate from the
// fraction-lost value of a receiver report (loss/256) received at time
// now (callers pass mono.Microseconds()).
//
// An estimate above maxLossRate means there has been no recent feedback,
// so it is first reset to initLossRate.  Then:
//   - loss < 5 (under about 2%): the estimate grows by about 5% (269/256),
//     capped at maxLossRate.  This only happens if the track is actually
//     sending at least 7/8 of the current estimate; otherwise the low loss
//     says nothing about the bottleneck.
//   - loss > 25 (over about 10%): the estimate is multiplied by
//     1 - loss/512, i.e. reduced by half the loss fraction, with a floor
//     of minLossRate.
//
// The estimate is stored even when unchanged, which refreshes its
// timestamp.
func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
	rate := track.maxLossBitrate.Get(now)
	if rate > maxLossRate {
		rate = initLossRate
	}

	switch {
	case loss < 5:
		actual := 8 * uint64(track.rate.Estimate())
		if actual >= rate*7/8 {
			rate = rate * 269 / 256
			if rate > maxLossRate {
				rate = maxLossRate
			}
		}
	case loss > 25:
		rate = rate * (512 - uint64(loss)) / 512
		if rate < minLossRate {
			rate = minLossRate
		}
	}

	track.maxLossBitrate.Set(rate, now)
}
// rtcpDownListener reads RTCP feedback sent by the client for a
// downstream track until ReadRTCP fails (io.EOF is not logged):
//   - PLI: relayed upstream as a PLI.
//   - FIR addressed to this track: relayed upstream as a FIR, falling back
//     to a PLI if FIR is unsupported.  The upstream FIR sequence number is
//     incremented only when the client's own sequence number changes.
//   - REMB: recorded as the track's REMB bitrate limit.
//   - Receiver report for this track: recorded in track.stats and fed to
//     updateRate.
//   - NACK: the requested packets are retransmitted from the upstream
//     cache while rate*7/8 is below the track's maximum bitrate.
func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) {
// the sequence number of the last FIR received from the client, valid
// once gotFir is set
var gotFir bool
lastFirSeqno := uint8(0)
for {
ps, err := s.ReadRTCP()
if err != nil {
if err != io.EOF {
log.Printf("ReadRTCP: %v", err)
}
return
}
for _, p := range ps {
switch p := p.(type) {
case *rtcp.PictureLossIndication:
err := conn.remote.sendPLI(track.remote)
if err != nil {
log.Printf("sendPLI: %v", err)
}
case *rtcp.FullIntraRequest:
found := false
var seqno uint8
for _, entry := range p.FIR {
if entry.SSRC == track.track.SSRC() {
found = true
seqno = entry.SequenceNumber
break
}
}
if !found {
log.Printf("Misdirected FIR")
continue
}
// an unchanged sequence number is a repetition of the
// previous request, so the upstream seqno is reused
increment := true
if gotFir {
increment = seqno != lastFirSeqno
}
gotFir = true
lastFirSeqno = seqno
err := conn.remote.sendFIR(
track.remote, increment,
)
if err == ErrUnsupportedFeedback {
err := conn.remote.sendPLI(track.remote)
if err != nil {
log.Printf("sendPLI: %v", err)
}
} else if err != nil {
log.Printf("sendFIR: %v", err)
}
case *rtcp.ReceiverEstimatedMaximumBitrate:
track.maxREMBBitrate.Set(
p.Bitrate, mono.Microseconds(),
)
case *rtcp.ReceiverReport:
for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() {
now := mono.Microseconds()
track.stats.Set(
r.FractionLost,
r.Jitter,
now,
)
track.updateRate(
r.FractionLost,
now,
)
}
}
case *rtcp.TransportLayerNack:
// NOTE(review): updateRate multiplies rate.Estimate() by 8 before
// comparing it with a bitrate, but this comparison does not;
// confirm that both sides use the same units.
maxBitrate := track.GetMaxBitrate(
mono.Microseconds(),
)
bitrate := track.rate.Estimate()
if uint64(bitrate)*7/8 < maxBitrate {
sendRecovery(p, track)
}
}
}
}
}
// trackKinds reports whether the downstream connection's peer connection
// has an audio sender and whether it has a video sender, counting only
// senders with a track attached.  Both results are false when down.pc is
// nil.
func trackKinds(down *rtpDownConnection) (audio bool, video bool) {
	if down.pc == nil {
		return false, false
	}
	for _, sender := range down.pc.GetSenders() {
		t := sender.Track()
		if t == nil {
			continue
		}
		kind := t.Kind()
		audio = audio || kind == webrtc.RTPCodecTypeAudio
		video = video || kind == webrtc.RTPCodecTypeVideo
	}
	return audio, video
}
// updateUpBitrate recomputes maxBitrate for every track of up.  The
// result is the smallest limit reported by the track's local
// (downstream) tracks, never below minVideoRate (video) or minAudioRate
// (audio).  For video it also never exceeds maxVideoRate, unless
// maxVideoRate is below minVideoRate, in which case minVideoRate is used.
// A local track reporting ^uint64(0) has no limit and is ignored, so an
// audio track with no limited local tracks keeps ^uint64(0).
func updateUpBitrate(up *upConnection, maxVideoRate uint64) {
	now := mono.Microseconds()
	for _, track := range up.tracks {
		var floor, rate uint64
		if track.track.Kind() == webrtc.RTPCodecTypeVideo {
			floor = minVideoRate
			rate = maxVideoRate
			if rate < floor {
				rate = floor
			}
		} else {
			floor = uint64(minAudioRate)
			rate = ^uint64(0)
		}

		for _, l := range track.getLocal() {
			limit := l.GetMaxBitrate(now)
			if limit == ^uint64(0) {
				continue
			}
			if limit <= floor {
				// cannot go lower than the floor, stop looking
				rate = floor
				break
			}
			if limit < rate {
				rate = limit
			}
		}

		track.maxBitrate = rate
	}
}
// Errors returned by the upConnection feedback helpers.
var (
	// ErrUnsupportedFeedback is returned when the track does not support
	// the requested RTCP feedback type (see hasRtcpFb).
	ErrUnsupportedFeedback = errors.New("unsupported feedback type")
	// ErrRateLimited is returned when a request is suppressed because a
	// similar one was sent too recently.
	ErrRateLimited = errors.New("rate limited")
)
// sendPLI asks the sender of track for a keyframe with an RTCP Picture
// Loss Indication.  It returns ErrUnsupportedFeedback if the track does
// not support "nack pli" feedback, and ErrRateLimited if a PLI was sent
// for this track less than 200ms (200000µs) ago.
func (up *upConnection) sendPLI(track *upTrack) error {
	if !track.hasRtcpFb("nack", "pli") {
		return ErrUnsupportedFeedback
	}
	last := atomic.LoadUint64(&track.lastPLI)
	now := mono.Microseconds()
	if now >= last && now-last < 200000 {
		return ErrRateLimited
	}
	// This is called concurrently (from writeLoop and from each
	// downstream rtcpDownListener).  A compare-and-swap claims the slot,
	// so at most one caller gets past the rate limit; a plain store would
	// let several through at once.
	if !atomic.CompareAndSwapUint64(&track.lastPLI, last, now) {
		return ErrRateLimited
	}
	return sendPLI(up.pc, track.track.SSRC())
}
// sendPLI writes a single RTCP Picture Loss Indication for ssrc on pc.
func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error {
	pli := &rtcp.PictureLossIndication{MediaSSRC: ssrc}
	return pc.WriteRTCP([]rtcp.Packet{pli})
}
// sendFIR asks the sender of track for a keyframe with an RTCP Full Intra
// Request.  If increment is true, a new FIR sequence number is allocated;
// otherwise the current one is reused, so the request counts as a
// repetition.  The sequence number is updated even when the request is
// then not sent.  It returns ErrUnsupportedFeedback if the track does not
// support "ccm fir" feedback, and ErrRateLimited if a FIR was sent for
// this track less than 200ms (200000µs) ago.
func (up *upConnection) sendFIR(track *upTrack, increment bool) error {
	var seqno uint8
	if increment {
		seqno = uint8(atomic.AddUint32(&track.firSeqno, 1) & 0xFF)
	} else {
		seqno = uint8(atomic.LoadUint32(&track.firSeqno) & 0xFF)
	}

	if !track.hasRtcpFb("ccm", "fir") {
		return ErrUnsupportedFeedback
	}

	last := atomic.LoadUint64(&track.lastFIR)
	now := mono.Microseconds()
	if now >= last && now-last < 200000 {
		return ErrRateLimited
	}
	// A compare-and-swap ensures that concurrent callers (writeLoop and
	// each downstream rtcpDownListener) cannot all pass the rate limit at
	// the same time.
	if !atomic.CompareAndSwapUint64(&track.lastFIR, last, now) {
		return ErrRateLimited
	}
	return sendFIR(up.pc, track.track.SSRC(), seqno)
}
// sendFIR writes an RTCP Full Intra Request for ssrc on pc, carrying the
// given FIR command sequence number.
func sendFIR(pc *webrtc.PeerConnection, ssrc uint32, seqno uint8) error {
	entry := rtcp.FIREntry{SSRC: ssrc, SequenceNumber: seqno}
	fir := &rtcp.FullIntraRequest{FIR: []rtcp.FIREntry{entry}}
	return pc.WriteRTCP([]rtcp.Packet{fir})
}
// sendREMB writes a Receiver Estimated Maximum Bitrate message on pc that
// advertises bitrate for the single stream ssrc.
func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
	remb := &rtcp.ReceiverEstimatedMaximumBitrate{
		SSRCs:   []uint32{ssrc},
		Bitrate: bitrate,
	}
	return pc.WriteRTCP([]rtcp.Packet{remb})
}
// sendNACK asks for retransmission of packet first and of the packets
// flagged in bitmap.  It does nothing, and returns nil, if the track does
// not support generic NACK feedback.  After a successful send, the packet
// cache's Expect is called with the number of packets requested.
func (up *upConnection) sendNACK(track *upTrack, first uint16, bitmap uint16) error {
	if track.hasRtcpFb("nack", "") {
		if err := sendNACK(up.pc, track.track.SSRC(), first, bitmap); err != nil {
			return err
		}
		track.cache.Expect(1 + bits.OnesCount16(bitmap))
	}
	return nil
}
// sendNACK writes an RTCP generic NACK (RFC 4585) for ssrc on pc.  It
// requests packet first and the packets flagged in bitmap, which covers
// the 16 sequence numbers that follow first.
func sendNACK(pc *webrtc.PeerConnection, ssrc uint32, first uint16, bitmap uint16) error {
	nack := &rtcp.TransportLayerNack{
		MediaSSRC: ssrc,
		Nacks: []rtcp.NackPair{{
			// keyed fields: go vet flags unkeyed literals of
			// imported struct types
			PacketID:    first,
			LostPackets: rtcp.PacketBitmap(bitmap),
		}},
	}
	return pc.WriteRTCP([]rtcp.Packet{nack})
}
// sendRecovery retransmits on the downstream track every packet listed in
// the NACK p that is still in the upstream track's cache, adding the
// bytes sent to the track's rate estimate.  Packets that are missing or
// cannot be parsed are skipped silently; write errors are logged.
func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) {
	buf := make([]byte, packetcache.BufSize)
	var packet rtp.Packet
	for _, pair := range p.Nacks {
		for _, seqno := range pair.PacketList() {
			n := track.remote.cache.Get(seqno, buf)
			if n == 0 {
				continue
			}
			if packet.Unmarshal(buf[:n]) != nil {
				continue
			}
			if err := track.track.WriteRTP(&packet); err != nil {
				log.Printf("WriteRTP: %v", err)
				continue
			}
			track.rate.Add(uint32(n))
		}
	}
}
// negotiate creates an SDP offer for the downstream connection, sets it
// as the local description, and sends it to the client in an "offer"
// message.  The message also carries a map from each transceiver's mid to
// the label of the upstream track forwarded on it.
func negotiate(c *webClient, down *rtpDownConnection) error {
	offer, err := down.pc.CreateOffer(nil)
	if err != nil {
		return err
	}
	if err = down.pc.SetLocalDescription(offer); err != nil {
		return err
	}

	labels := make(map[string]string)
	for _, tr := range down.pc.GetTransceivers() {
		sender := tr.Sender()
		if sender == nil {
			continue
		}
		local := sender.Track()
		if local == nil {
			continue
		}
		for _, dt := range down.tracks {
			if dt.track == local {
				labels[tr.Mid()] = dt.remote.label
			}
		}
	}

	return c.write(clientMessage{
		Type:   "offer",
		Id:     down.id,
		Offer:  &offer,
		Labels: labels,
	})
}
// sendICE forwards a local ICE candidate for connection id to the client
// in an "ice" message.  A nil candidate is ignored.
func sendICE(c *webClient, id string, candidate *webrtc.ICECandidate) error {
	if candidate != nil {
		cinit := candidate.ToJSON()
		return c.write(clientMessage{
			Type:      "ice",
			Id:        id,
			Candidate: &cinit,
		})
	}
	return nil
}
// gotOffer handles an SDP offer from the client for upstream connection
// id, creating the connection if needed, and sends back the answer.
// labels maps transceiver mids to track labels; the OnTrack handler
// installed by addUpConn uses it to name the incoming tracks.
func gotOffer(c *webClient, id string, offer webrtc.SessionDescription, labels map[string]string) error {
	// look the connection up under c.mu rather than indexing c.up directly
	up := getUpConn(c, id)
	if up == nil {
		var err error
		up, err = addUpConn(c, id)
		if err != nil {
			return err
		}
	}

	// The OnTrack handler reads up.labels with c.mu held, from another
	// goroutine, as soon as tracks arrive.  Set the labels under the same
	// lock and before the offer is applied.
	c.mu.Lock()
	if c.username != "" {
		up.label = c.username
	}
	up.labels = labels
	c.mu.Unlock()

	err := up.pc.SetRemoteDescription(offer)
	if err != nil {
		return err
	}
	answer, err := up.pc.CreateAnswer(nil)
	if err != nil {
		return err
	}
	err = up.pc.SetLocalDescription(answer)
	if err != nil {
		return err
	}
	err = up.flushICECandidates()
	if err != nil {
		log.Printf("ICE: %v", err)
	}
	return c.write(clientMessage{
		Type:   "answer",
		Id:     id,
		Answer: &answer,
	})
}
// gotAnswer handles the client's SDP answer for downstream connection id:
//   - it applies the answer as the remote description;
//   - it calls flushICECandidates, logging any error;
//   - it registers each of the connection's tracks with the upstream track
//     it forwards.
// An unknown id is a protocol error.
func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error {
	down := getDownConn(c, id)
	if down == nil {
		return protocolError("unknown id in answer")
	}
	if err := down.pc.SetRemoteDescription(answer); err != nil {
		return err
	}
	if err := down.flushICECandidates(); err != nil {
		log.Printf("ICE: %v", err)
	}
	for _, t := range down.tracks {
		t.remote.addLocal(t)
	}
	return nil
}
// gotICE adds a remote ICE candidate, received from the client, to the
// upstream or downstream connection with the given id.
func gotICE(c *webClient, candidate *webrtc.ICECandidateInit, id string) error {
	if conn := getConn(c, id); conn != nil {
		return conn.addICECandidate(candidate)
	}
	return errors.New("unknown id in ICE")
}
// setRequested replaces the set of track labels the client wants to
// receive.  Every existing downstream connection is closed, and the
// client is sent a "close" message for each.  pushConns then runs
// asynchronously so that the other clients push their connections again
// under the new request.  It always returns nil.
func (c *webClient) setRequested(requested map[string]uint32) error {
	// ranging over a nil map is a no-op; deleting while ranging is safe
	for id := range c.down {
		c.write(clientMessage{Type: "close", Id: id})
		delDownConn(c, id)
	}
	c.requested = requested
	go pushConns(c)
	return nil
}
// pushConns queues a pushConnsAction on every web client returned by
// c.getGroup().getClients(c) (presumably all members but c), asking each
// to push its upstream connections to c.  Clients that are not
// *webClient are skipped.
func pushConns(c client) {
	for _, other := range c.getGroup().getClients(c) {
		if wc, ok := other.(*webClient); ok {
			wc.action(pushConnsAction{c})
		}
	}
}
// isRequested reports whether the client asked for a non-zero rate for
// tracks with the given label.
func (c *webClient) isRequested(label string) bool {
	rate := c.requested[label]
	return rate > 0
}
// addDownConnTracks creates a downstream connection that forwards those
// tracks (from remote) whose label the client requested.  It returns
// (nil, nil) if none was requested.  If adding a track fails, the
// partially built connection is removed again.
func addDownConnTracks(c *webClient, remote *upConnection, tracks []*upTrack) (*rtpDownConnection, error) {
	var wanted []*upTrack
	for _, t := range tracks {
		if c.isRequested(t.label) {
			wanted = append(wanted, t)
		}
	}
	if len(wanted) == 0 {
		return nil, nil
	}

	down, err := addDownConn(c, remote.id, remote)
	if err != nil {
		return nil, err
	}

	for _, t := range wanted {
		if _, err := addDownTrack(c, down, t, remote); err != nil {
			delDownConn(c, down.id)
			return nil, err
		}
	}
	return down, nil
}
// pushConn offers upstream connection conn, with the given tracks, to
// this client by queueing an addConnAction.  If label is not empty, it
// then queues an addLabelAction that announces label for conn.  It fails
// only if queueing fails, i.e. once the client has terminated.
func (c *webClient) pushConn(conn *upConnection, tracks []*upTrack, label string) error {
	if err := c.action(addConnAction{conn, tracks}); err != nil {
		return err
	}
	if label == "" {
		return nil
	}
	// Announce the label the caller passed in (the one just tested).
	// Re-reading conn.label could give a different value, since the
	// owning client may update it concurrently.
	return c.action(addLabelAction{conn.id, label})
}
// clientLoop is the main loop of a web client.  It starts clientReader
// and sends the client its permissions and the group's chat history.  It
// then serves the following until an error occurs:
//   - messages from the client, passed to handleClientMessage;
//   - actions queued by other goroutines on c.actionCh;
//   - a one-second ticker that sends rate updates (REMB) upstream;
//   - a ten-second ticker that pings the client after 60s without a
//     message and gives up after 90s.
// On exit, all downstream and upstream connections are torn down.
func clientLoop(c *webClient, conn *websocket.Conn) error {
	read := make(chan interface{}, 1)
	go clientReader(conn, read, c.done)

	defer func() {
		c.setRequested(map[string]uint32{})
		for _, id := range getUpConns(c) {
			delUpConn(c, id)
		}
	}()

	c.write(clientMessage{
		Type:        "permissions",
		Permissions: c.permissions,
	})

	for _, m := range c.group.getChatHistory() {
		err := c.write(clientMessage{
			Type:     "chat",
			Id:       m.id,
			Username: m.user,
			Value:    m.value,
			Me:       m.me,
		})
		if err != nil {
			return err
		}
	}

	readTime := time.Now()

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	slowTicker := time.NewTicker(10 * time.Second)
	defer slowTicker.Stop()

	for {
		select {
		case m, ok := <-read:
			if !ok {
				return errors.New("reader died")
			}
			switch m := m.(type) {
			case clientMessage:
				readTime = time.Now()
				err := handleClientMessage(c, m)
				if err != nil {
					return err
				}
			case error:
				return m
			}
		case a := <-c.actionCh:
			switch a := a.(type) {
			case addConnAction:
				down, err := addDownConnTracks(
					c, a.conn, a.tracks,
				)
				if err != nil {
					return err
				}
				if down != nil {
					err = negotiate(c, down)
					if err != nil {
						log.Printf("Negotiate: %v", err)
						delDownConn(c, down.id)
						err = failConnection(
							c, down.id,
							"negotiation failed",
						)
						if err != nil {
							return err
						}
						continue
					}
				}
			case delConnAction:
				found := delDownConn(c, a.id)
				if found {
					c.write(clientMessage{
						Type: "close",
						Id:   a.id,
					})
				}
			case addLabelAction:
				c.write(clientMessage{
					Type:  "label",
					Id:    a.id,
					Value: a.label,
				})
			case pushConnsAction:
				// The OnTrack handler appends to u.tracks with c.mu
				// held, so the snapshot must be taken under the same
				// lock.  The go statement evaluates its arguments
				// immediately, i.e. while the lock is still held.
				c.mu.Lock()
				for _, u := range c.up {
					tracks := make([]*upTrack, len(u.tracks))
					copy(tracks, u.tracks)
					go a.c.pushConn(u, tracks, u.label)
				}
				c.mu.Unlock()
			case connectionFailedAction:
				found := delUpConn(c, a.id)
				if found {
					err := failConnection(c, a.id,
						"connection failed")
					if err != nil {
						return err
					}
					continue
				}
				// What should we do if a downstream
				// connection fails? Renegotiate?
			case permissionsChangedAction:
				c.write(clientMessage{
					Type:        "permissions",
					Permissions: c.permissions,
				})
				if !c.permissions.Present {
					ids := getUpConns(c)
					for _, id := range ids {
						found := delUpConn(c, id)
						if found {
							failConnection(
								c, id,
								"permission denied",
							)
						}
					}
				}
			case kickAction:
				return userError("you have been kicked")
			default:
				log.Printf("unexpected action %T", a)
				return errors.New("unexpected action")
			}
		case <-ticker.C:
			sendRateUpdate(c)
		case <-slowTicker.C:
			if time.Since(readTime) > 90*time.Second {
				return errors.New("client is dead")
			}
			if time.Since(readTime) > 60*time.Second {
				err := c.write(clientMessage{
					Type: "ping",
				})
				if err != nil {
					return err
				}
			}
		}
	}
}
// failConnection reports a failed connection to the client.  It sends an
// "abort" message for id (if id is not empty), then an "error" message
// carrying message (if message is not empty).  It returns the first
// error encountered while writing.
func failConnection(c *webClient, id string, message string) error {
	if id != "" {
		if err := c.write(clientMessage{Type: "abort", Id: id}); err != nil {
			return err
		}
	}
	if message == "" {
		return nil
	}
	return c.error(userError(message))
}
// setPermissions applies the permission change perm ("op", "unop",
// "present" or "unpresent") to the web client with the given id in group
// g, then notifies that client with a permissionsChangedAction.  Granting
// "op" also grants recording if the group allows it; "unop" revokes both.
//
// The action is queued only after g.mu has been released.  c.action
// blocks while the client's action queue is full, and holding the group
// lock in the meantime could deadlock if that client's goroutines are
// themselves waiting for g.mu.  NOTE(review): this assumes group methods
// such as getClients/delClient take g.mu; confirm.
func setPermissions(g *group, id string, perm string) error {
	target, err := func() (*webClient, error) {
		g.mu.Lock()
		defer g.mu.Unlock()

		client := g.getClientUnlocked(id)
		if client == nil {
			return nil, userError("no such user")
		}
		c, ok := client.(*webClient)
		if !ok {
			return nil, userError("this is not a real user")
		}

		switch perm {
		case "op":
			c.permissions.Op = true
			if g.description.AllowRecording {
				c.permissions.Record = true
			}
		case "unop":
			c.permissions.Op = false
			c.permissions.Record = false
		case "present":
			c.permissions.Present = true
		case "unpresent":
			c.permissions.Present = false
		default:
			return nil, userError("unknown permission")
		}
		return c, nil
	}()
	if err != nil {
		return err
	}

	err = target.action(permissionsChangedAction{})
	if err == ErrClientDead {
		// the client left after g.mu was released; nothing to notify
		return nil
	}
	return err
}
// kickClient asks the web client with the given id in group g to
// terminate, by queueing a kickAction.
//
// The action is queued only after g.mu has been released.  c.action may
// block while the client's action queue is full, and must not do so while
// holding the group lock (NOTE(review): assumes the client's goroutines
// may need g.mu; confirm).  A client that leaves in the meantime counts
// as kicked.
func kickClient(g *group, id string) error {
	target, err := func() (*webClient, error) {
		g.mu.Lock()
		defer g.mu.Unlock()

		client := g.getClientUnlocked(id)
		if client == nil {
			return nil, userError("no such user")
		}
		c, ok := client.(*webClient)
		if !ok {
			return nil, userError("this is not a real user")
		}
		return c, nil
	}()
	if err != nil {
		return err
	}

	err = target.action(kickAction{})
	if err == ErrClientDead {
		return nil
	}
	return err
}
// handleClientMessage processes one message received from the client.
//
// A non-nil return value ends the client's session.  Some requests are
// reported back to the client with an "error" message via c.error and do
// not end the session; this covers requests the client is not authorised
// to make and requests that fail with a userError.  Unknown message types
// are protocol errors.
func handleClientMessage(c *webClient, m clientMessage) error {
	switch m.Type {
	case "request":
		err := c.setRequested(m.Request)
		if err != nil {
			return err
		}
	case "offer":
		if !c.permissions.Present {
			c.write(clientMessage{
				Type: "abort",
				Id:   m.Id,
			})
			return c.error(userError("not authorised"))
		}
		if m.Offer == nil {
			return protocolError("null offer")
		}
		err := gotOffer(c, m.Id, *m.Offer, m.Labels)
		if err != nil {
			log.Printf("gotOffer: %v", err)
			return failConnection(c, m.Id, "negotiation failed")
		}
	case "answer":
		if m.Answer == nil {
			return protocolError("null answer")
		}
		err := gotAnswer(c, m.Id, *m.Answer)
		if err != nil {
			return err
		}
	case "close":
		found := delUpConn(c, m.Id)
		if !found {
			log.Printf("Deleting unknown up connection %v", m.Id)
		}
	case "ice":
		if m.Candidate == nil {
			return protocolError("null candidate")
		}
		err := gotICE(c, m.Candidate, m.Id)
		if err != nil {
			log.Printf("ICE: %v", err)
		}
	case "chat":
		// Do not trust the identity the client claims.  Stamp the
		// message with the sender's own id and username before storing
		// and relaying it, so that nobody can post as another user.
		m.Id = c.id
		m.Username = c.username
		c.group.addToChatHistory(m.Id, m.Username, m.Value, m.Me)
		clients := c.group.getClients(c)
		for _, cc := range clients {
			cc, ok := cc.(*webClient)
			if ok {
				cc.write(m)
			}
		}
	case "clearchat":
		c.group.clearChatHistory()
		m := clientMessage{Type: "clearchat"}
		clients := c.group.getClients(nil)
		for _, cc := range clients {
			cc, ok := cc.(*webClient)
			if ok {
				cc.write(m)
			}
		}
	case "op", "unop", "present", "unpresent":
		if !c.permissions.Op {
			return c.error(userError("not authorised"))
		}
		err := setPermissions(c.group, m.Id, m.Type)
		if err != nil {
			return c.error(err)
		}
	case "lock", "unlock":
		if !c.permissions.Op {
			return c.error(userError("not authorised"))
		}
		var locked uint32
		if m.Type == "lock" {
			locked = 1
		}
		atomic.StoreUint32(&c.group.locked, locked)
	case "record":
		if !c.permissions.Record {
			return c.error(userError("not authorised"))
		}
		for _, cc := range c.group.getClients(c) {
			_, ok := cc.(*diskClient)
			if ok {
				return c.error(userError("already recording"))
			}
		}
		disk := &diskClient{
			group: c.group,
			id:    "recording",
		}
		_, err := addClient(c.group.name, disk, "", "")
		if err != nil {
			disk.Close()
			return c.error(err)
		}
		go pushConns(disk)
	case "unrecord":
		if !c.permissions.Record {
			return c.error(userError("not authorised"))
		}
		for _, cc := range c.group.getClients(c) {
			disk, ok := cc.(*diskClient)
			if ok {
				disk.Close()
				delClient(disk)
			}
		}
	case "kick":
		if !c.permissions.Op {
			return c.error(userError("not authorised"))
		}
		err := kickClient(c.group, m.Id)
		if err != nil {
			return c.error(err)
		}
	case "pong":
		// nothing
	case "ping":
		c.write(clientMessage{
			Type: "pong",
		})
	default:
		log.Printf("unexpected message: %v", m.Type)
		return protocolError("unexpected message")
	}
	return nil
}
// sendRateUpdate recomputes the maximum bitrate of the client's upstream
// tracks.  It then sends a REMB for each track that supports "goog-remb"
// feedback and has a finite limit.  When the group has three or more
// video tracks, video is capped at 2000000/sqrt(count), but not below
// minVideoRate.  The REMBs are sent after c.mu has been released.
func sendRateUpdate(c *webClient) {
	type remb struct {
		pc      *webrtc.PeerConnection
		ssrc    uint32
		bitrate uint64
	}

	maxVideoRate := ^uint64(0)
	if n := atomic.LoadUint32(&c.group.videoCount); n >= 3 {
		maxVideoRate = uint64(2000000 / math.Sqrt(float64(n)))
		if maxVideoRate < minVideoRate {
			maxVideoRate = minVideoRate
		}
	}

	var pending []remb
	c.mu.Lock()
	for _, u := range c.up {
		updateUpBitrate(u, maxVideoRate)
		for _, t := range u.tracks {
			if !t.hasRtcpFb("goog-remb", "") || t.maxBitrate == ^uint64(0) {
				continue
			}
			pending = append(pending,
				remb{u.pc, t.track.SSRC(), t.maxBitrate})
		}
	}
	c.mu.Unlock()

	for _, r := range pending {
		if err := sendREMB(r.pc, r.ssrc, r.bitrate); err != nil {
			log.Printf("sendREMB: %v", err)
		}
	}
}
// clientReader decodes JSON messages from conn and forwards them on read.
// It stops after forwarding the first read error (the error value itself
// is sent), or as soon as done is closed.  read is closed on exit.
func clientReader(conn *websocket.Conn, read chan<- interface{}, done <-chan struct{}) {
	defer close(read)
	for {
		var m clientMessage
		err := conn.ReadJSON(&m)

		var item interface{} = m
		if err != nil {
			item = err
		}

		select {
		case read <- item:
		case <-done:
			return
		}

		if err != nil {
			return
		}
	}
}
// clientWriter writes the messages received on ch to conn, each with a
// 2-second write deadline.  clientMessage values are written as JSON and
// closeMessage values as WebSocket close frames.  It returns when ch is
// closed, on the first write error, or on a message of unexpected type.
// On return it closes done, then conn.
func clientWriter(conn *websocket.Conn, ch <-chan interface{}, done chan<- struct{}) {
	defer func() {
		close(done)
		conn.Close()
	}()

	for m := range ch {
		err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
		if err != nil {
			return
		}
		switch m := m.(type) {
		case clientMessage:
			err = conn.WriteJSON(m)
		case closeMessage:
			err = conn.WriteMessage(websocket.CloseMessage, m.data)
		default:
			// function name was misspelled "clientWiter" in this log
			log.Printf("clientWriter: unexpected message %T", m)
			return
		}
		if err != nil {
			return
		}
	}
}
// Errors returned by webClient.write and webClient.action once the
// writer goroutine, respectively the client itself, has terminated.
var (
	ErrWriterDead = errors.New("client writer died")
	ErrClientDead = errors.New("client dead")
)
// action queues m on the client's action channel, where clientLoop will
// handle it.  It blocks while the channel is full and returns
// ErrClientDead if the client terminates first.
func (c *webClient) action(m interface{}) error {
	select {
	case <-c.done:
		return ErrClientDead
	case c.actionCh <- m:
		return nil
	}
}
// write queues m for the clientWriter goroutine.  It blocks while the
// write queue is full and returns ErrWriterDead once the writer has
// exited.
//
// NOTE(review): startClient closes c.writeCh (and then sets it to nil)
// when the session ends, without synchronisation.  A write from another
// goroutine racing with that close, e.g. sendICE from a pion callback,
// would panic with "send on closed channel".  Worth confirming and
// guarding.
func (c *webClient) write(m clientMessage) error {
	select {
	case <-c.writerDone:
		return ErrWriterDead
	case c.writeCh <- m:
		return nil
	}
}
// error reports err to the client if it is a userError: it sends an
// "error" message and returns the result of that write.  Any other error
// is returned unchanged, for the caller to handle.
func (c *webClient) error(err error) error {
	if ue, ok := err.(userError); ok {
		return c.write(clientMessage{Type: "error", Value: string(ue)})
	}
	return err
}