From b26a8cad786b225986d7f0a36302a0c12d3f4432 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Sun, 17 May 2020 22:31:29 +0200 Subject: [PATCH] Label tracks explicitly. For now, this is only used to request screen sharing as opposed to normal videos. In the future, it will be used for simulcasting. --- client.go | 89 +++++++++++++++++++------------------------------ group.go | 46 ++++++++++++++++--------- static/sfu.css | 5 +++ static/sfu.html | 8 +++-- static/sfu.js | 36 ++++++++++++++++---- 5 files changed, 106 insertions(+), 78 deletions(-) diff --git a/client.go b/client.go index 74bf645..9c3b2c5 100644 --- a/client.go +++ b/client.go @@ -26,7 +26,6 @@ import ( "github.com/gorilla/websocket" "github.com/pion/rtcp" "github.com/pion/rtp" - "github.com/pion/sdp" "github.com/pion/webrtc/v2" ) @@ -104,6 +103,7 @@ type clientMessage struct { Offer *webrtc.SessionDescription `json:"offer,omitempty"` Answer *webrtc.SessionDescription `json:"answer,omitempty"` Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"` + Labels map[string]string `json:"labels,omitempty"` Del bool `json:"del,omitempty"` Request []string `json:"request,omitempty"` } @@ -301,13 +301,14 @@ func addUpConn(c *client, id string) (*upConnection, error) { } track := &upTrack{ track: remote, + label: u.labels[remote.ID()], cache: packetcache.New(96), rate: estimator.New(time.Second), jitter: jitter.New(remote.Codec().ClockRate), maxBitrate: ^uint64(0), } u.tracks = append(u.tracks, track) - done := len(u.tracks) >= u.trackCount + done := u.complete() if remote.Kind() == webrtc.RTPCodecTypeVideo { atomic.AddUint32(&c.group.videoCount, 1) } @@ -830,28 +831,27 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) { } } -func countMediaStreams(data string) (int, error) { - desc := sdp.NewJSEPSessionDescription(false) - err := desc.Unmarshal(data) +func negotiate(c *client, id string, down *downConnection) error { + offer, err := down.pc.CreateOffer(nil) if err != nil { - return 0, err + return err } - return len(desc.MediaDescriptions), nil -} -func negotiate(c *client, id string, pc *webrtc.PeerConnection) error { - offer, err := pc.CreateOffer(nil) + err = down.pc.SetLocalDescription(offer) if err != nil { return err } - err = pc.SetLocalDescription(offer) - if err != nil { - return err + + labels := make(map[string]string) + for _, t := range down.tracks { + labels[t.track.ID()] = t.remote.label } + return c.write(clientMessage{ - Type: "offer", - Id: id, - Offer: &offer, + Type: "offer", + Id: id, + Offer: &offer, + Labels: labels, }) } @@ -867,7 +867,7 @@ func sendICE(c *client, id string, candidate *webrtc.ICECandidate) error { }) } -func gotOffer(c *client, offer webrtc.SessionDescription, id string) error { +func gotOffer(c *client, id string, offer webrtc.SessionDescription, labels map[string]string) error { var err error up, ok := c.up[id] if !ok { @@ -879,12 +879,6 @@ func gotOffer(c *client, offer webrtc.SessionDescription, id string) error { if c.username != "" { up.label = c.username } - n, err := countMediaStreams(offer.SDP) - if err != nil { - log.Printf("Couldn't parse SDP: %v", err) - n = 2 - } - up.trackCount = n err = up.pc.SetRemoteDescription(offer) if err != nil { return err @@ -900,6 +894,8 @@ func gotOffer(c *client, offer webrtc.SessionDescription, id string) error { return err } + up.labels = labels + return c.write(clientMessage{ Type: "answer", Id: id, @@ -907,7 +903,7 @@ func gotOffer(c *client, offer webrtc.SessionDescription, id string) error { }) } -func gotAnswer(c *client, answer webrtc.SessionDescription, id string) error { +func gotAnswer(c *client, id string, answer webrtc.SessionDescription) error { conn := getDownConn(c, id) if conn == nil { return protocolError("unknown id in answer") @@ -934,11 +930,7 @@ func gotICE(c *client, candidate *webrtc.ICECandidateInit, id string) error { return pc.AddICECandidate(*candidate) } -func (c *client) setRequested(audio, video bool) error { - if audio == c.requestedAudio && video == c.requestedVideo { - return nil - } - +func (c *client) setRequested(requested []string) error { if c.down != nil { for id := range c.down { c.write(clientMessage{ @@ -949,8 +941,7 @@ func (c *client) setRequested(audio, video bool) error { } } - c.requestedAudio = audio - c.requestedVideo = video + c.requested = requested go func() { clients := c.group.getClients(c) @@ -962,15 +953,13 @@ func (c *client) setRequested(audio, video bool) error { return nil } -func (c *client) requested(kind webrtc.RTPCodecType) bool { - switch kind { - case webrtc.RTPCodecTypeAudio: - return c.requestedAudio - case webrtc.RTPCodecTypeVideo: - return c.requestedVideo - default: - return false +func (c *client) isRequested(label string) bool { + for _, r := range c.requested { + if label == r { + return true + } } + return false } func pushTracks(c *client, conn *upConnection, tracks []*upTrack, done bool, label string) { @@ -989,7 +978,7 @@ func clientLoop(c *client, conn *websocket.Conn) error { go clientReader(conn, read, c.done) defer func() { - c.setRequested(false, false) + c.setRequested([]string{}) if c.up != nil { for id := range c.up { delUpConn(c, id) @@ -1044,7 +1033,7 @@ func clientLoop(c *client, conn *websocket.Conn) error { case addTrackAction: var down *downConnection var err error - if c.requested(a.track.track.Kind()) { + if c.isRequested(a.track.label) { down, _, err = addDownTrack( c, a.remote.id, a.track, a.remote) @@ -1055,7 +1044,7 @@ func clientLoop(c *client, conn *websocket.Conn) error { down = getDownConn(c, a.remote.id) } if a.done && down != nil { - err = negotiate(c, a.remote.id, down.pc) + err = negotiate(c, a.remote.id, down) if err != nil { return err } @@ -1080,7 +1069,7 @@ func clientLoop(c *client, conn *websocket.Conn) error { copy(tracks, u.tracks) go pushTracks( a.c, u, tracks, - len(tracks) >= u.trackCount-1, + u.complete(), u.label, ) } @@ -1141,15 +1130,7 @@ func clientLoop(c *client, conn *websocket.Conn) error { func handleClientMessage(c *client, m clientMessage) error { switch m.Type { case "request": - var audio, video bool - for _, s := range m.Request { - switch(s) { - case "audio": audio = true - case "video": video = true - default: log.Printf("Unknown request %v", s) - } - } - err := c.setRequested(audio, video) + err := c.setRequested(m.Request) if err != nil { return err } @@ -1164,7 +1145,7 @@ func handleClientMessage(c *client, m clientMessage) error { if m.Offer == nil { return protocolError("null offer") } - err := gotOffer(c, *m.Offer, m.Id) + err := gotOffer(c, m.Id, *m.Offer, m.Labels) if err != nil { return err } @@ -1172,7 +1153,7 @@ func handleClientMessage(c *client, m clientMessage) error { if m.Answer == nil { return protocolError("null answer") } - err := gotAnswer(c, *m.Answer, m.Id) + err := gotAnswer(c, m.Id, *m.Answer) if err != nil { return err } diff --git a/group.go b/group.go index 521ea7c..25423c7 100644 --- a/group.go +++ b/group.go @@ -26,6 +26,7 @@ import ( type upTrack struct { track *webrtc.Track + label string rate *estimator.Estimator cache *packetcache.Cache jitter *jitter.Estimator @@ -74,11 +75,27 @@ func (up *upTrack) hasRtcpFb(tpe, parameter string) bool { } type upConnection struct { - id string - label string - pc *webrtc.PeerConnection - trackCount int - tracks []*upTrack + id string + label string + pc *webrtc.PeerConnection + tracks []*upTrack + labels map[string]string +} + +func (up *upConnection) complete() bool { + for id, _ := range up.labels { + found := false + for _, t := range up.tracks { + if t.track.ID() == id { + found = true + break + } + } + if !found { + return false + } + } + return true } type bitrate struct { @@ -150,16 +167,15 @@ type downConnection struct { } type client struct { - group *group - id string - username string - permissions userPermission - requestedAudio bool - requestedVideo bool - done chan struct{} - writeCh chan interface{} - writerDone chan struct{} - actionCh chan interface{} + group *group + id string + username string + permissions userPermission + requested []string + done chan struct{} + writeCh chan interface{} + writerDone chan struct{} + actionCh chan interface{} mu sync.Mutex down map[string]*downConnection diff --git a/static/sfu.css b/static/sfu.css index c558e21..3547240 100644 --- a/static/sfu.css +++ b/static/sfu.css @@ -65,6 +65,11 @@ h1 { margin-right: 0.4em; } +#requestselect { + width: 8em; + text-align-last: center; +} + #main { display: flex; } diff --git a/static/sfu.html b/static/sfu.html index 37322b8..4363e77 100644 --- a/static/sfu.html +++ b/static/sfu.html @@ -47,8 +47,12 @@ - - + + diff --git a/static/sfu.js b/static/sfu.js index ed0693a..65182f1 100644 --- a/static/sfu.js +++ b/static/sfu.js @@ -39,6 +39,7 @@ function Connection(id, pc) { this.label = null; this.pc = pc; this.stream = null; + this.labels = {}; this.iceCandidates = []; this.timers = []; this.audioStats = {}; @@ -140,9 +141,9 @@ document.getElementById('sharebox').onchange = function(e) { setShareMedia(this.checked); }; -document.getElementById('requestbox').onchange = function(e) { +document.getElementById('requestselect').onchange = function(e) { e.preventDefault(); - sendRequest(this.checked); + sendRequest(this.value); }; async function updateStats(conn, sender) { @@ -312,6 +313,7 @@ async function setLocalMedia(setup) { let c = up[localMediaId]; c.stream = stream; stream.getTracks().forEach(t => { + c.labels[t.id] = t.kind let sender = c.pc.addTrack(t, stream); c.setInterval(() => { updateStats(c, sender); @@ -358,6 +360,7 @@ async function setShareMedia(setup) { document.getElementById('sharebox').checked = false; setShareMedia(false); }; + c.labels[t.id] = 'screenshare'; c.setInterval(() => { updateStats(c, sender); }, 2000); @@ -485,7 +488,7 @@ function serverConnect() { username: up.username, password: up.password, }); - sendRequest(document.getElementById('requestbox').checked); + sendRequest(document.getElementById('requestselect').value); resolve(); }; socket.onclose = function(e) { @@ -508,7 +511,7 @@ function serverConnect() { let m = JSON.parse(e.data); switch(m.type) { case 'offer': - gotOffer(m.id, m.offer); + gotOffer(m.id, m.labels, m.offer); break; case 'answer': gotAnswer(m.id, m.answer); @@ -556,14 +559,30 @@ function serverConnect() { }); } -function sendRequest(video) { +function sendRequest(value) { + let request = []; + switch(value) { + case 'audio': + request = ['audio']; + break; + case 'screenshare': + request = ['audio', 'screenshare']; + break; + case 'everything': + request = ['audio', 'screenshare', 'video']; + break; + default: + console.error(`Uknown value ${value} in sendRequest`); + break; + } + send({ type: 'request', - request: video ? ['audio', 'video'] : ['audio'], + request: request, }); } -async function gotOffer(id, offer) { +async function gotOffer(id, labels, offer) { let c = down[id]; if(!c) { let pc = new RTCPeerConnection({ @@ -587,6 +606,8 @@ async function gotOffer(id, offer) { }; } + c.labels = labels; + await c.pc.setRemoteDescription(offer); await addIceCandidates(c); let answer = await c.pc.createAnswer(); @@ -1007,6 +1028,7 @@ async function negotiate(id) { send({ type: 'offer', id: id, + labels: c.labels, offer: offer, }); }