diff --git a/client.go b/client.go
index 74bf645..9c3b2c5 100644
--- a/client.go
+++ b/client.go
@@ -26,7 +26,6 @@ import (
"github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/rtp"
- "github.com/pion/sdp"
"github.com/pion/webrtc/v2"
)
@@ -104,6 +103,7 @@ type clientMessage struct {
Offer *webrtc.SessionDescription `json:"offer,omitempty"`
Answer *webrtc.SessionDescription `json:"answer,omitempty"`
Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"`
+ Labels map[string]string `json:"labels,omitempty"`
Del bool `json:"del,omitempty"`
Request []string `json:"request,omitempty"`
}
@@ -301,13 +301,14 @@ func addUpConn(c *client, id string) (*upConnection, error) {
}
track := &upTrack{
track: remote,
+ label: u.labels[remote.ID()],
cache: packetcache.New(96),
rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate),
maxBitrate: ^uint64(0),
}
u.tracks = append(u.tracks, track)
- done := len(u.tracks) >= u.trackCount
+ done := u.complete()
if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.group.videoCount, 1)
}
@@ -830,28 +831,27 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *downTrack) {
}
}
-func countMediaStreams(data string) (int, error) {
- desc := sdp.NewJSEPSessionDescription(false)
- err := desc.Unmarshal(data)
+func negotiate(c *client, id string, down *downConnection) error {
+ offer, err := down.pc.CreateOffer(nil)
if err != nil {
- return 0, err
+ return err
}
- return len(desc.MediaDescriptions), nil
-}
-func negotiate(c *client, id string, pc *webrtc.PeerConnection) error {
- offer, err := pc.CreateOffer(nil)
+ err = down.pc.SetLocalDescription(offer)
if err != nil {
return err
}
- err = pc.SetLocalDescription(offer)
- if err != nil {
- return err
+
+ labels := make(map[string]string)
+ for _, t := range down.tracks {
+ labels[t.track.ID()] = t.remote.label
}
+
return c.write(clientMessage{
- Type: "offer",
- Id: id,
- Offer: &offer,
+ Type: "offer",
+ Id: id,
+ Offer: &offer,
+ Labels: labels,
})
}
@@ -867,7 +867,7 @@ func sendICE(c *client, id string, candidate *webrtc.ICECandidate) error {
})
}
-func gotOffer(c *client, offer webrtc.SessionDescription, id string) error {
+func gotOffer(c *client, id string, offer webrtc.SessionDescription, labels map[string]string) error {
var err error
up, ok := c.up[id]
if !ok {
@@ -879,12 +879,6 @@ func gotOffer(c *client, offer webrtc.SessionDescription, id string) error {
if c.username != "" {
up.label = c.username
}
- n, err := countMediaStreams(offer.SDP)
- if err != nil {
- log.Printf("Couldn't parse SDP: %v", err)
- n = 2
- }
- up.trackCount = n
err = up.pc.SetRemoteDescription(offer)
if err != nil {
return err
@@ -900,6 +894,8 @@ func gotOffer(c *client, offer webrtc.SessionDescription, id string) error {
return err
}
+ up.labels = labels
+
return c.write(clientMessage{
Type: "answer",
Id: id,
@@ -907,7 +903,7 @@ func gotOffer(c *client, offer webrtc.SessionDescription, id string) error {
})
}
-func gotAnswer(c *client, answer webrtc.SessionDescription, id string) error {
+func gotAnswer(c *client, id string, answer webrtc.SessionDescription) error {
conn := getDownConn(c, id)
if conn == nil {
return protocolError("unknown id in answer")
@@ -934,11 +930,7 @@ func gotICE(c *client, candidate *webrtc.ICECandidateInit, id string) error {
return pc.AddICECandidate(*candidate)
}
-func (c *client) setRequested(audio, video bool) error {
- if audio == c.requestedAudio && video == c.requestedVideo {
- return nil
- }
-
+func (c *client) setRequested(requested []string) error {
if c.down != nil {
for id := range c.down {
c.write(clientMessage{
@@ -949,8 +941,7 @@ func (c *client) setRequested(audio, video bool) error {
}
}
- c.requestedAudio = audio
- c.requestedVideo = video
+ c.requested = requested
go func() {
clients := c.group.getClients(c)
@@ -962,15 +953,13 @@ func (c *client) setRequested(audio, video bool) error {
return nil
}
-func (c *client) requested(kind webrtc.RTPCodecType) bool {
- switch kind {
- case webrtc.RTPCodecTypeAudio:
- return c.requestedAudio
- case webrtc.RTPCodecTypeVideo:
- return c.requestedVideo
- default:
- return false
+func (c *client) isRequested(label string) bool {
+ for _, r := range c.requested {
+ if label == r {
+ return true
+ }
}
+ return false
}
func pushTracks(c *client, conn *upConnection, tracks []*upTrack, done bool, label string) {
@@ -989,7 +978,7 @@ func clientLoop(c *client, conn *websocket.Conn) error {
go clientReader(conn, read, c.done)
defer func() {
- c.setRequested(false, false)
+ c.setRequested([]string{})
if c.up != nil {
for id := range c.up {
delUpConn(c, id)
@@ -1044,7 +1033,7 @@ func clientLoop(c *client, conn *websocket.Conn) error {
case addTrackAction:
var down *downConnection
var err error
- if c.requested(a.track.track.Kind()) {
+ if c.isRequested(a.track.label) {
down, _, err = addDownTrack(
c, a.remote.id, a.track,
a.remote)
@@ -1055,7 +1044,7 @@ func clientLoop(c *client, conn *websocket.Conn) error {
down = getDownConn(c, a.remote.id)
}
if a.done && down != nil {
- err = negotiate(c, a.remote.id, down.pc)
+ err = negotiate(c, a.remote.id, down)
if err != nil {
return err
}
@@ -1080,7 +1069,7 @@ func clientLoop(c *client, conn *websocket.Conn) error {
copy(tracks, u.tracks)
go pushTracks(
a.c, u, tracks,
- len(tracks) >= u.trackCount-1,
+ u.complete(),
u.label,
)
}
@@ -1141,15 +1130,7 @@ func clientLoop(c *client, conn *websocket.Conn) error {
func handleClientMessage(c *client, m clientMessage) error {
switch m.Type {
case "request":
- var audio, video bool
- for _, s := range m.Request {
- switch(s) {
- case "audio": audio = true
- case "video": video = true
- default: log.Printf("Unknown request %v", s)
- }
- }
- err := c.setRequested(audio, video)
+ err := c.setRequested(m.Request)
if err != nil {
return err
}
@@ -1164,7 +1145,7 @@ func handleClientMessage(c *client, m clientMessage) error {
if m.Offer == nil {
return protocolError("null offer")
}
- err := gotOffer(c, *m.Offer, m.Id)
+ err := gotOffer(c, m.Id, *m.Offer, m.Labels)
if err != nil {
return err
}
@@ -1172,7 +1153,7 @@ func handleClientMessage(c *client, m clientMessage) error {
if m.Answer == nil {
return protocolError("null answer")
}
- err := gotAnswer(c, *m.Answer, m.Id)
+ err := gotAnswer(c, m.Id, *m.Answer)
if err != nil {
return err
}
diff --git a/group.go b/group.go
index 521ea7c..25423c7 100644
--- a/group.go
+++ b/group.go
@@ -26,6 +26,7 @@ import (
type upTrack struct {
track *webrtc.Track
+ label string
rate *estimator.Estimator
cache *packetcache.Cache
jitter *jitter.Estimator
@@ -74,11 +75,27 @@ func (up *upTrack) hasRtcpFb(tpe, parameter string) bool {
}
type upConnection struct {
- id string
- label string
- pc *webrtc.PeerConnection
- trackCount int
- tracks []*upTrack
+ id string
+ label string
+ pc *webrtc.PeerConnection
+ tracks []*upTrack
+ labels map[string]string
+}
+
+func (up *upConnection) complete() bool {
+ for id, _ := range up.labels {
+ found := false
+ for _, t := range up.tracks {
+ if t.track.ID() == id {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return false
+ }
+ }
+ return true
}
type bitrate struct {
@@ -150,16 +167,15 @@ type downConnection struct {
}
type client struct {
- group *group
- id string
- username string
- permissions userPermission
- requestedAudio bool
- requestedVideo bool
- done chan struct{}
- writeCh chan interface{}
- writerDone chan struct{}
- actionCh chan interface{}
+ group *group
+ id string
+ username string
+ permissions userPermission
+ requested []string
+ done chan struct{}
+ writeCh chan interface{}
+ writerDone chan struct{}
+ actionCh chan interface{}
mu sync.Mutex
down map[string]*downConnection
diff --git a/static/sfu.css b/static/sfu.css
index c558e21..3547240 100644
--- a/static/sfu.css
+++ b/static/sfu.css
@@ -65,6 +65,11 @@ h1 {
margin-right: 0.4em;
}
+#requestselect {
+ width: 8em;
+ text-align-last: center;
+}
+
#main {
display: flex;
}
diff --git a/static/sfu.html b/static/sfu.html
index 37322b8..4363e77 100644
--- a/static/sfu.html
+++ b/static/sfu.html
@@ -47,8 +47,12 @@
-
-
+
+
diff --git a/static/sfu.js b/static/sfu.js
index ed0693a..65182f1 100644
--- a/static/sfu.js
+++ b/static/sfu.js
@@ -39,6 +39,7 @@ function Connection(id, pc) {
this.label = null;
this.pc = pc;
this.stream = null;
+ this.labels = {};
this.iceCandidates = [];
this.timers = [];
this.audioStats = {};
@@ -140,9 +141,9 @@ document.getElementById('sharebox').onchange = function(e) {
setShareMedia(this.checked);
};
-document.getElementById('requestbox').onchange = function(e) {
+document.getElementById('requestselect').onchange = function(e) {
e.preventDefault();
- sendRequest(this.checked);
+ sendRequest(this.value);
};
async function updateStats(conn, sender) {
@@ -312,6 +313,7 @@ async function setLocalMedia(setup) {
let c = up[localMediaId];
c.stream = stream;
stream.getTracks().forEach(t => {
+ c.labels[t.id] = t.kind
let sender = c.pc.addTrack(t, stream);
c.setInterval(() => {
updateStats(c, sender);
@@ -358,6 +360,7 @@ async function setShareMedia(setup) {
document.getElementById('sharebox').checked = false;
setShareMedia(false);
};
+ c.labels[t.id] = 'screenshare';
c.setInterval(() => {
updateStats(c, sender);
}, 2000);
@@ -485,7 +488,7 @@ function serverConnect() {
username: up.username,
password: up.password,
});
- sendRequest(document.getElementById('requestbox').checked);
+ sendRequest(document.getElementById('requestselect').value);
resolve();
};
socket.onclose = function(e) {
@@ -508,7 +511,7 @@ function serverConnect() {
let m = JSON.parse(e.data);
switch(m.type) {
case 'offer':
- gotOffer(m.id, m.offer);
+ gotOffer(m.id, m.labels, m.offer);
break;
case 'answer':
gotAnswer(m.id, m.answer);
@@ -556,14 +559,30 @@ function serverConnect() {
});
}
-function sendRequest(video) {
+function sendRequest(value) {
+ let request = [];
+ switch(value) {
+ case 'audio':
+ request = ['audio'];
+ break;
+ case 'screenshare':
+ request = ['audio', 'screenshare'];
+ break;
+ case 'everything':
+ request = ['audio', 'screenshare', 'video'];
+ break;
+ default:
+ console.error(`Uknown value ${value} in sendRequest`);
+ break;
+ }
+
send({
type: 'request',
- request: video ? ['audio', 'video'] : ['audio'],
+ request: request,
});
}
-async function gotOffer(id, offer) {
+async function gotOffer(id, labels, offer) {
let c = down[id];
if(!c) {
let pc = new RTCPeerConnection({
@@ -587,6 +606,8 @@ async function gotOffer(id, offer) {
};
}
+ c.labels = labels;
+
await c.pc.setRemoteDescription(offer);
await addIceCandidates(c);
let answer = await c.pc.createAnswer();
@@ -1007,6 +1028,7 @@ async function negotiate(id) {
send({
type: 'offer',
id: id,
+ labels: c.labels,
offer: offer,
});
}