diff --git a/diskwriter/diskwriter.go b/diskwriter/diskwriter.go index c5816e4..80326d7 100644 --- a/diskwriter/diskwriter.go +++ b/diskwriter/diskwriter.go @@ -24,6 +24,12 @@ import ( gcodecs "github.com/jech/galene/codecs" "github.com/jech/galene/conn" "github.com/jech/galene/group" + "github.com/jech/galene/rtptime" +) + +const ( + audioMaxLate = 32 + videoMaxLate = 256 ) var Directory string @@ -170,6 +176,8 @@ type diskConn struct { tracks []*diskTrack width, height uint32 lastWarning time.Time + originLocal time.Time + originRemote uint64 } // called locked @@ -200,6 +208,9 @@ func (conn *diskConn) open(extension string) error { // called locked func (conn *diskConn) close() []*diskTrack { + conn.originLocal = time.Time{} + conn.originRemote = 0 + tracks := make([]*diskTrack, 0, len(conn.tracks)) for _, t := range conn.tracks { t.writeBuffered(true) @@ -207,6 +218,7 @@ func (conn *diskConn) close() []*diskTrack { t.writer.Close() t.writer = nil } + t.origin = none tracks = append(tracks, t) } conn.file = nil @@ -282,7 +294,11 @@ type diskTrack struct { writer mkvcore.BlockWriteCloser builder *samplebuilder.SampleBuilder lastSeqno maybeUint32 - origin maybeUint32 + + origin maybeUint32 + + remoteNTP uint64 + remoteRTP uint32 kfRequested time.Time lastKf time.Time @@ -339,21 +355,25 @@ func newDiskConn(client *Client, directory string, up conn.Up, remoteTracks []co codec := remote.Codec() if strings.EqualFold(codec.MimeType, "audio/opus") { builder = samplebuilder.New( - 16, &codecs.OpusPacket{}, codec.ClockRate, + audioMaxLate, + &codecs.OpusPacket{}, codec.ClockRate, ) } else if strings.EqualFold(codec.MimeType, "video/vp8") { builder = samplebuilder.New( - 256, &codecs.VP8Packet{}, codec.ClockRate, + videoMaxLate, + &codecs.VP8Packet{}, codec.ClockRate, ) conn.hasVideo = true } else if strings.EqualFold(codec.MimeType, "video/vp9") { builder = samplebuilder.New( - 256, &codecs.VP9Packet{}, codec.ClockRate, + videoMaxLate, &codecs.VP9Packet{}, + codec.ClockRate, ) conn.hasVideo = true } else if strings.EqualFold(codec.MimeType, "video/h264") { builder = samplebuilder.New( - 256, &codecs.H264Packet{}, codec.ClockRate, + videoMaxLate, &codecs.H264Packet{}, + codec.ClockRate, ) conn.hasVideo = true } else { @@ -387,9 +407,6 @@ func newDiskConn(client *Client, directory string, up conn.Up, remoteTracks []co return &conn, nil } -func (t *diskTrack) SetTimeOffset(ntp uint64, rtp uint32) { -} - func (t *diskTrack) SetCname(string) { } @@ -475,18 +492,33 @@ func (t *diskTrack) writeRTP(p *rtp.Packet) error { if kf { t.savedKf = p t.lastKf = time.Now() + if !valid(t.origin) { + t.setOrigin( + p.Timestamp, time.Now(), + t.remote.Codec().ClockRate, + ) + } } else if time.Since(t.lastKf) > 4*time.Second { requestKeyframe(t) } } + if !valid(t.origin) { + if !t.conn.hasVideo || !t.conn.originLocal.Equal(time.Time{}) { + t.setOrigin( + p.Timestamp, time.Now(), + t.remote.Codec().ClockRate, + ) + } + } + t.builder.Push(p) return t.writeBuffered(false) } -// writeBuffered writes any buffered samples to disk. If force is true, -// then samples will be flushed even if they are preceded by incomplete +// writeBuffered writes buffered samples to disk. If force is true, then +// samples will be flushed even if they are preceded by incomplete // samples. func (t *diskTrack) writeBuffered(force bool) error { codec := t.remote.Codec().MimeType @@ -503,6 +535,16 @@ func (t *diskTrack) writeBuffered(force bool) error { return nil } + if valid(t.origin) && int32(ts-value(t.origin)) < 0 { + if value(t.origin)-ts < 0x10000 { + // late packet before origin, drop + continue + } + // we've gone around 2^31 timestamps, force + // creating a new file to avoid wraparound + t.conn.close() + } + var keyframe bool if len(codec) > 6 && strings.EqualFold(codec[:6], "video/") { if t.savedKf == nil { @@ -512,11 +554,10 @@ func (t *diskTrack) writeBuffered(force bool) error { } if keyframe { - err := t.conn.initWriter( - gcodecs.KeyframeDimensions( - codec, t.savedKf, - ), + w, h := gcodecs.KeyframeDimensions( + codec, t.savedKf, ) + err := t.conn.initWriter(w, h, t, ts) if err != nil { t.conn.warn( "Write to disk " + err.Error(), @@ -528,7 +569,7 @@ func (t *diskTrack) writeBuffered(force bool) error { keyframe = true if t.writer == nil { if !t.conn.hasVideo { - err := t.conn.initWriter(0, 0) + err := t.conn.initWriter(0, 0, t, ts) if err != nil { t.conn.warn( "Write to disk " + @@ -545,11 +586,12 @@ func (t *diskTrack) writeBuffered(force bool) error { } if !valid(t.origin) { - t.origin = some(ts) + log.Println("Invalid origin") + return nil } - ts -= value(t.origin) - tm := ts / (t.remote.Codec().ClockRate / 1000) + tm := (ts - value(t.origin)) / + (t.remote.Codec().ClockRate / 1000) _, err := t.writer.Write(keyframe, int64(tm), sample.Data) if err != nil { return err @@ -557,8 +599,114 @@ func (t *diskTrack) writeBuffered(force bool) error { } } +// setOrigin sets the origin of track t after receiving a packet with +// timestamp ts at local time now. // called locked -func (conn *diskConn) initWriter(width, height uint32) error { +func (t *diskTrack) setOrigin(ts uint32, now time.Time, clockrate uint32) { + sub := func(a, b uint32, hz uint32) time.Duration { + return rtptime.ToDuration(int64(int32(a-b)), hz) + } + + if t.conn.originLocal.Equal(time.Time{}) { + t.origin = some(ts) + t.conn.originLocal = now + if t.remoteNTP != 0 { + remote := rtptime.NTPToTime(t.remoteNTP).Add( + sub(ts, t.remoteRTP, clockrate), + ) + t.conn.originRemote = rtptime.TimeToNTP(remote) + } else { + t.conn.originRemote = 0 + } + } else if t.conn.originRemote != 0 && t.remoteNTP != 0 { + remote := rtptime.NTPToTime(t.remoteNTP).Add( + sub(ts, t.remoteRTP, clockrate), + ) + origin := rtptime.NTPToTime(t.conn.originRemote) + delta := rtptime.FromDuration(remote.Sub(origin), clockrate) + t.origin = some(ts - uint32(delta)) + } else { + d := now.Sub(t.conn.originLocal) + delta := rtptime.FromDuration(d, clockrate) + t.origin = some(ts - uint32(delta)) + if t.remoteNTP != 0 { + remote := rtptime.NTPToTime(t.remoteNTP).Add( + sub(ts, t.remoteRTP, clockrate), + ) + t.conn.originRemote = rtptime.TimeToNTP( + remote.Add(-d), + ) + } + } +} + +// SetTimeOffset adjusts the origin of track t given remote sync information. +func (t *diskTrack) SetTimeOffset(ntp uint64, rtp uint32) { + t.conn.mu.Lock() + defer t.conn.mu.Unlock() + t.setTimeOffset(ntp, rtp, t.remote.Codec().ClockRate) +} + +// called locked +func (t *diskTrack) setTimeOffset(ntp uint64, rtp uint32, clockrate uint32) { + if valid(t.origin) { + local := rtptime.ToDuration( + int64(int32(rtp-value(t.origin))), clockrate, + ) + if t.conn.originRemote == 0 { + t.conn.originRemote = + rtptime.TimeToNTP( + rtptime.NTPToTime(ntp).Add(-local)) + } else { + remote := rtptime.NTPToTime(ntp).Sub( + rtptime.NTPToTime(t.conn.originRemote)) + delta := rtptime.FromDuration(remote-local, clockrate) + t.origin = some(value(t.origin) - uint32(delta)) + } + } + + t.remoteNTP = ntp + t.remoteRTP = rtp +} + +// adjustOrigin adjusts all origin-related fields of all tracks so that +// the origin of track t is equal to ts. +// Called locked. +func (t *diskTrack) adjustOrigin(ts uint32) { + if !valid(t.origin) || value(t.origin) == ts { + return + } + + offset := rtptime.ToDuration( + int64(int32(ts-value(t.origin))), t.remote.Codec().ClockRate, + ) + + if !t.conn.originLocal.Equal(time.Time{}) { + t.conn.originLocal = t.conn.originLocal.Add(offset) + } + if t.conn.originRemote != 0 { + t.conn.originRemote = + rtptime.TimeToNTP( + rtptime.NTPToTime( + t.conn.originRemote, + ).Add(offset), + ) + } + + for _, tt := range t.conn.tracks { + if valid(tt.origin) { + tt.origin = some(value(tt.origin) + + uint32(rtptime.FromDuration( + offset, + tt.remote.Codec().ClockRate, + )), + ) + } + } +} + +// called locked +func (conn *diskConn) initWriter(width, height uint32, track *diskTrack, ts uint32) error { if conn.file != nil { if width == conn.width && height == conn.height { return nil @@ -637,6 +785,10 @@ func (conn *diskConn) initWriter(width, height uint32) error { header = &h } + if track != nil { + track.adjustOrigin(ts) + } + err := conn.open(extension) if err != nil { return err @@ -644,8 +796,8 @@ func (conn *diskConn) initWriter(width, height uint32) error { interceptor, err := mkvcore.NewMultiTrackBlockSorter( // must be larger than the samplebuilder's MaxLate. - mkvcore.WithMaxDelayedPackets(384), - mkvcore.WithSortRule(mkvcore.BlockSorterDropOutdated), + mkvcore.WithMaxDelayedPackets(videoMaxLate+16), + mkvcore.WithSortRule(mkvcore.BlockSorterWriteOutdated), ) if err != nil { conn.file.Close() diff --git a/diskwriter/diskwriter_test.go b/diskwriter/diskwriter_test.go new file mode 100644 index 0000000..f304023 --- /dev/null +++ b/diskwriter/diskwriter_test.go @@ -0,0 +1,157 @@ +package diskwriter + +import ( + "testing" + "time" + + "github.com/jech/galene/rtptime" +) + +func TestAdjustOriginLocalNow(t *testing.T) { + now := time.Now() + + c := &diskConn{ + tracks: []*diskTrack{ + &diskTrack{}, + }, + } + for _, t := range c.tracks { + t.conn = c + } + c.tracks[0].setOrigin(132, now, 100) + + if !c.originLocal.Equal(now) { + t.Errorf("Expected %v, got %v", now, c.originLocal) + } + + if c.originRemote != 0 { + t.Errorf("Expected 0, got %v", c.originRemote) + } + + if c.tracks[0].origin != some(132) { + t.Errorf("Expected 132, got %v", value(c.tracks[0].origin)) + } +} + +func TestAdjustOriginLocalEarlier(t *testing.T) { + now := time.Now() + earlier := now.Add(-time.Second) + + c := &diskConn{ + originLocal: earlier, + tracks: []*diskTrack{ + &diskTrack{}, + }, + } + for _, t := range c.tracks { + t.conn = c + } + c.tracks[0].setOrigin(132, now, 100) + + if !c.originLocal.Equal(earlier) { + t.Errorf("Expected %v, got %v", earlier, c.originLocal) + } + + if c.originRemote != 0 { + t.Errorf("Expected 0, got %v", c.originRemote) + } + + if c.tracks[0].origin != some(32) { + t.Errorf("Expected 32, got %v", value(c.tracks[0].origin)) + } +} + +func TestAdjustOriginLocalLater(t *testing.T) { + now := time.Now() + later := now.Add(time.Second) + + c := &diskConn{ + originLocal: later, + tracks: []*diskTrack{ + &diskTrack{}, + }, + } + for _, t := range c.tracks { + t.conn = c + } + c.tracks[0].setOrigin(32, now, 100) + + if !c.originLocal.Equal(later) { + t.Errorf("Expected %v, got %v", later, c.originLocal) + } + + if c.originRemote != 0 { + t.Errorf("Expected 0, got %v", c.originRemote) + } + + if c.tracks[0].origin != some(132) { + t.Errorf("Expected 132, got %v", value(c.tracks[0].origin)) + } +} + +func TestAdjustOriginRemote(t *testing.T) { + now := time.Now() + earlier := now.Add(-time.Second) + + c := &diskConn{ + tracks: []*diskTrack{ + &diskTrack{ + remoteNTP: rtptime.TimeToNTP(earlier), + remoteRTP: 32, + }, + }, + } + for _, t := range c.tracks { + t.conn = c + } + c.tracks[0].setOrigin(132, now, 100) + + if !c.originLocal.Equal(now) { + t.Errorf("Expected %v, got %v", now, c.originLocal) + } + + d := now.Sub(rtptime.NTPToTime(c.originRemote)) + if d < -time.Millisecond || d > time.Millisecond { + t.Errorf("Expected %v, got %v (delta %v)", + rtptime.TimeToNTP(now), + c.originRemote, d) + } + + if c.tracks[0].origin != some(132) { + t.Errorf("Expected 132, got %v", value(c.tracks[0].origin)) + } +} + +func TestAdjustOriginLocalRemote(t *testing.T) { + now := time.Now() + earlier := now.Add(-time.Second) + + c := &diskConn{ + tracks: []*diskTrack{ + &diskTrack{}, + }, + } + for _, t := range c.tracks { + t.conn = c + } + c.tracks[0].setOrigin(132, now, 100) + + c.tracks[0].setTimeOffset(rtptime.TimeToNTP(earlier), 32, 100) + + c.tracks[0].setOrigin(132, now, 100) + + if !c.originLocal.Equal(now) { + t.Errorf("Expected %v, got %v", now, c.originLocal) + } + + d := now.Sub(rtptime.NTPToTime(c.originRemote)) + if d < -time.Millisecond || d > time.Millisecond { + t.Errorf("Expected %v, got %v (delta %v)", + rtptime.TimeToNTP(now), + c.originRemote, d) + } + + if c.tracks[0].origin != some(132) { + t.Errorf("Expected 132, got %v", value(c.tracks[0].origin)) + } +}