diff --git a/conn/conn.go b/conn/conn.go index c4a4c2c..d18ef66 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -27,7 +27,7 @@ type UpTrack interface { Codec() webrtc.RTPCodecCapability // get a recent packet. Returns 0 if the packet is not in cache. GetRTP(seqno uint16, result []byte) uint16 - Nack(conn Up, seqnos []uint16) error + Nack(seqnos []uint16) error RequestKeyframe() error } diff --git a/diskwriter/diskwriter.go b/diskwriter/diskwriter.go index 38c1348..8220051 100644 --- a/diskwriter/diskwriter.go +++ b/diskwriter/diskwriter.go @@ -510,7 +510,7 @@ func (t *diskTrack) Write(buf []byte) (int, error) { } } if len(nacks) > 0 { - t.remote.Nack(t.conn.remote, nacks) + t.remote.Nack(nacks) } } } diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index 6fedd55..06b9e96 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -80,6 +80,7 @@ type downTrackAtomics struct { type rtpDownTrack struct { track *webrtc.TrackLocalStaticRTP sender *webrtc.RTPSender + conn *rtpDownConnection remote conn.UpTrack ssrc webrtc.SSRC packetmap packetmap.Map @@ -351,6 +352,7 @@ func (down *rtpDownConnection) flushICECandidates() error { type rtpUpTrack struct { track *webrtc.TrackRemote + conn *rtpUpConnection rate *estimator.Estimator cache *packetcache.Cache jitter *jitter.Estimator @@ -616,6 +618,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon track := &rtpUpTrack{ track: remote, + conn: up, cache: packetcache.New(minPacketCache(remote)), rate: estimator.New(time.Second), jitter: jitter.New(remote.Codec().ClockRate), @@ -625,7 +628,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon up.tracks = append(up.tracks, track) - go readLoop(up, track) + go readLoop(track) go rtcpUpListener(up, track, receiver) @@ -643,11 +646,11 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon var ErrUnsupportedFeedback = errors.New("unsupported feedback type") var ErrRateLimited = errors.New("rate limited") -func (up *rtpUpConnection) sendPLI(track *rtpUpTrack) error { +func (track *rtpUpTrack) sendPLI() error { if !track.hasRtcpFb("nack", "pli") { return ErrUnsupportedFeedback } - return sendPLI(up.pc, track.track.SSRC()) + return sendPLI(track.conn.pc, track.track.SSRC()) } func sendPLI(pc *webrtc.PeerConnection, ssrc webrtc.SSRC) error { @@ -656,12 +659,12 @@ func sendPLI(pc *webrtc.PeerConnection, ssrc webrtc.SSRC) error { }) } -func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint16) error { +func (track *rtpUpTrack) sendNACK(first uint16, bitmap uint16) error { if !track.hasRtcpFb("nack", "") { return ErrUnsupportedFeedback } - err := sendNACKs(up.pc, track.track.SSRC(), + err := sendNACKs(track.conn.pc, track.track.SSRC(), []rtcp.NackPair{{first, rtcp.PacketBitmap(bitmap)}}, ) if err == nil { @@ -670,7 +673,7 @@ func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint return err } -func (up *rtpUpConnection) sendNACKs(track *rtpUpTrack, seqnos []uint16) error { +func (track *rtpUpTrack) sendNACKs(seqnos []uint16) error { count := len(seqnos) if count == 0 { return nil @@ -691,7 +694,7 @@ func (up *rtpUpConnection) sendNACKs(track *rtpUpTrack, seqnos []uint16) error { f, b, seqnos = packetcache.ToBitmap(seqnos) nacks = append(nacks, rtcp.NackPair{f, rtcp.PacketBitmap(b)}) } - err := sendNACKs(up.pc, track.track.SSRC(), nacks) + err := sendNACKs(track.conn.pc, track.track.SSRC(), nacks) if err == nil { track.cache.Expect(count) } @@ -708,7 +711,7 @@ func sendNACKs(pc *webrtc.PeerConnection, ssrc webrtc.SSRC, nacks []rtcp.NackPai return pc.WriteRTCP([]rtcp.Packet{packet}) } -func gotNACK(conn *rtpDownConnection, track *rtpDownTrack, p *rtcp.TransportLayerNack) { +func gotNACK(track *rtpDownTrack, p *rtcp.TransportLayerNack) { var unhandled []uint16 buf := make([]byte, packetcache.BufSize) for _, nack := range p.Nacks { @@ -734,10 +737,10 @@ func gotNACK(conn *rtpDownConnection, track *rtpDownTrack, p *rtcp.TransportLaye return } - track.remote.Nack(conn.remote, unhandled) + track.remote.Nack(unhandled) } -func (track *rtpUpTrack) Nack(conn conn.Up, nacks []uint16) error { +func (track *rtpUpTrack) Nack(nacks []uint16) error { track.mu.Lock() defer track.mu.Unlock() @@ -754,12 +757,7 @@ outer: } if doit { - up, ok := conn.(*rtpUpConnection) - if !ok { - log.Printf("Nack: unexpected type %T", conn) - return errors.New("unexpected connection type") - } - go nackWriter(up, track) + go nackWriter(track) } return nil } @@ -1088,7 +1086,7 @@ func (track *rtpDownTrack) updateRate(loss uint8, now uint64) { track.maxBitrate.Set(rate, now) } -func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) { +func rtcpDownListener(track *rtpDownTrack, s *webrtc.RTPSender) { lastFirSeqno := uint8(0) buf := make([]byte, 1500) @@ -1149,7 +1147,7 @@ func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RT } } case *rtcp.TransportLayerNack: - gotNACK(conn, track, p) + gotNACK(track, p) } } if adjust { diff --git a/rtpconn/rtpreader.go b/rtpconn/rtpreader.go index ad0ef81..c4a6886 100644 --- a/rtpconn/rtpreader.go +++ b/rtpconn/rtpreader.go @@ -12,8 +12,8 @@ import ( "github.com/jech/galene/rtptime" ) -func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { - writers := rtpWriterPool{conn: conn, track: track} +func readLoop(track *rtpUpTrack) { + writers := rtpWriterPool{track: track} defer func() { writers.close() close(track.readerDone) @@ -118,7 +118,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { packet.SequenceNumber - unnacked, ) if found && sendNACK { - err := conn.sendNACK(track, first, bitmap) + err := track.sendNACK(first, bitmap) if err != nil { log.Printf("%v", err) } @@ -136,7 +136,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { now := time.Now() if kfNeeded && now.Sub(kfRequested) > time.Second/2 { if sendPLI { - err := conn.sendPLI(track) + err := track.sendPLI() if err != nil { log.Printf("sendPLI: %v", err) kfNeeded = false diff --git a/rtpconn/rtpwriter.go b/rtpconn/rtpwriter.go index 81c47e6..a57a9e0 100644 --- a/rtpconn/rtpwriter.go +++ b/rtpconn/rtpwriter.go @@ -21,7 +21,6 @@ type packetIndex struct { // An rtpWriterPool is a set of rtpWriters type rtpWriterPool struct { - conn *rtpUpConnection track *rtpUpTrack writers []*rtpWriter count int @@ -72,7 +71,7 @@ func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error { } if add { - writer := newRtpWriter(wp.conn, wp.track) + writer := newRtpWriter(wp.track) wp.writers = append(wp.writers, writer) err := writer.add(track, true, n) if err == nil { @@ -181,13 +180,13 @@ type rtpWriter struct { drop int } -func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter { +func newRtpWriter(track *rtpUpTrack) *rtpWriter { writer := &rtpWriter{ ch: make(chan packetIndex, 32), done: make(chan struct{}), action: make(chan writerAction, 1), } - go rtpWriterLoop(writer, conn, track) + go rtpWriterLoop(writer, track) return writer } @@ -223,7 +222,7 @@ func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) { } // rtpWriterLoop is the main loop of an rtpWriter. -func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { +func rtpWriterLoop(writer *rtpWriter, track *rtpUpTrack) { defer close(writer.done) buf := make([]byte, packetcache.BufSize) @@ -314,7 +313,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { // nackWriter is called when bufferedNACKs becomes non-empty. It decides // which nacks to ship out. -func nackWriter(conn *rtpUpConnection, track *rtpUpTrack) { +func nackWriter(track *rtpUpTrack) { // a client might send us a NACK for a packet that has already // been nacked by the reader loop. Give recovery a chance. time.Sleep(50 * time.Millisecond) @@ -366,6 +365,6 @@ func nackWriter(conn *rtpUpConnection, track *rtpUpTrack) { }) if len(nacks) > 0 { - conn.sendNACKs(track, nacks) + track.sendNACKs(nacks) } } diff --git a/rtpconn/webclient.go b/rtpconn/webclient.go index 2921c70..19d43a0 100644 --- a/rtpconn/webclient.go +++ b/rtpconn/webclient.go @@ -399,6 +399,7 @@ func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remo track: local, sender: sender, ssrc: parms.Encodings[0].SSRC, + conn: conn, remote: remoteTrack, maxBitrate: new(bitrate), maxREMBBitrate: new(bitrate), @@ -409,7 +410,7 @@ func addDownTrackUnlocked(conn *rtpDownConnection, remoteTrack *rtpUpTrack, remo conn.tracks = append(conn.tracks, track) - go rtcpDownListener(conn, track, sender) + go rtcpDownListener(track, sender) return nil }