diff --git a/client.go b/client.go index 2bdaaf2..6d32472 100644 --- a/client.go +++ b/client.go @@ -654,7 +654,7 @@ func getDownConn(c *client, id string) *downConnection { return conn } -func getConn(c *client, id string) connection { +func getConn(c *client, id string) iceConnection { up := getUpConn(c, id) if up != nil { return up @@ -1053,6 +1053,11 @@ func gotOffer(c *client, id string, offer webrtc.SessionDescription, labels map[ up.labels = labels + err = up.flushICECandidates() + if err != nil { + log.Printf("ICE: %v", err) + } + return c.write(clientMessage{ Type: "answer", Id: id, @@ -1061,17 +1066,22 @@ func gotOffer(c *client, id string, offer webrtc.SessionDescription, labels map[ } func gotAnswer(c *client, id string, answer webrtc.SessionDescription) error { - conn := getDownConn(c, id) - if conn == nil { + down := getDownConn(c, id) + if down == nil { return protocolError("unknown id in answer") } - err := conn.pc.SetRemoteDescription(answer) + err := down.pc.SetRemoteDescription(answer) if err != nil { return err } - for _, t := range conn.tracks { - activateDownTrack(conn, t) + err = down.flushICECandidates() + if err != nil { + log.Printf("ICE: %v", err) + } + + for _, t := range down.tracks { + activateDownTrack(down, t) } return nil } @@ -1081,7 +1091,7 @@ func gotICE(c *client, candidate *webrtc.ICECandidateInit, id string) error { if conn == nil { return errors.New("unknown id in ICE") } - return conn.getPC().AddICECandidate(*candidate) + return conn.addICECandidate(candidate) } func (c *client) setRequested(requested map[string]uint32) error { diff --git a/conn.go b/conn.go index 2223da2..f0b5611 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,7 @@ package main import ( + "errors" "sync" "sync/atomic" @@ -16,8 +17,9 @@ import ( "github.com/pion/webrtc/v2" ) -type connection interface { - getPC() *webrtc.PeerConnection +type iceConnection interface { + addICECandidate(candidate *webrtc.ICECandidateInit) error + flushICECandidates() error } type upTrack struct { @@ -30,6 +32,7 @@ type upTrack struct { lastPLI uint64 lastSenderReport uint32 lastSenderReportTime uint32 + iceCandidates []*webrtc.ICECandidateInit localCh chan struct{} // signals that local has changed writerDone chan struct{} // closed when the loop dies @@ -103,6 +106,35 @@ func (up *upConnection) getPC() *webrtc.PeerConnection { return up.pc } +func (up *upConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error { + if up.pc.RemoteDescription() != nil { + return up.pc.AddICECandidate(*candidate) + } + up.iceCandidates = append(up.iceCandidates, candidate) + return nil +} + +func flushICECandidates(pc *webrtc.PeerConnection, candidates []*webrtc.ICECandidateInit) error { + if pc.RemoteDescription() == nil { + return errors.New("flushICECandidates called in bad state") + } + + var err error + for _, candidate := range candidates { + err2 := pc.AddICECandidate(*candidate) + if err == nil { + err = err2 + } + } + return err +} + +func (up *upConnection) flushICECandidates() error { + err := flushICECandidates(up.pc, up.iceCandidates) + up.iceCandidates = nil + return err +} + func getUpMid(pc *webrtc.PeerConnection, track *webrtc.Track) string { for _, t := range pc.GetTransceivers() { if t.Receiver() != nil && t.Receiver().Track() == track { @@ -201,3 +233,17 @@ type downConnection struct { func (down *downConnection) getPC() *webrtc.PeerConnection { return down.pc } + +func (down *downConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error { + if down.pc.RemoteDescription() != nil { + return down.pc.AddICECandidate(*candidate) + } + down.iceCandidates = append(down.iceCandidates, candidate) + return nil +} + +func (down *downConnection) flushICECandidates() error { + err := flushICECandidates(down.pc, down.iceCandidates) + down.iceCandidates = nil + return err +}