1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 08:35:57 +01:00

Split out SDP fragment parsing, add test.

This commit is contained in:
Juliusz Chroboczek 2024-09-30 00:23:36 +02:00
parent 45bbb138c6
commit df274ad6ea
2 changed files with 51 additions and 14 deletions

View file

@ -1,6 +1,7 @@
package webserver package webserver
import ( import (
"bufio"
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
@ -345,23 +346,25 @@ func whipResourceHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, sdpLimit)) err = parseSDPFrag(
http.MaxBytesReader(w, r.Body, sdpLimit),
c.GotICECandidate,
)
if err != nil { if err != nil {
httpError(w, err) log.Printf("WHIP trickle ICE: %v", err)
http.Error(w, "bad request", http.StatusBadRequest)
return return
} }
w.WriteHeader(http.StatusNoContent)
}
if len(body) < 2 { // RFC 8840
http.Error(w, "SDP truncated", http.StatusBadRequest) func parseSDPFrag(r io.Reader, f func(webrtc.ICECandidateInit) error) error {
return scanner := bufio.NewScanner(r)
}
// RFC 8840
lines := bytes.Split(body, []byte{'\n'})
mLineIndex := -1 mLineIndex := -1
var mid, ufrag []byte var mid, ufrag []byte
for _, l := range lines { for scanner.Scan() {
l = bytes.TrimRight(l, " \r") l := scanner.Bytes()
if bytes.HasPrefix(l, []byte("a=ice-ufrag:")) { if bytes.HasPrefix(l, []byte("a=ice-ufrag:")) {
ufrag = l[len("a=ice-ufrag:"):] ufrag = l[len("a=ice-ufrag:"):]
} else if bytes.HasPrefix(l, []byte("m=")) { } else if bytes.HasPrefix(l, []byte("m=")) {
@ -385,12 +388,11 @@ func whipResourceHandler(w http.ResponseWriter, r *http.Request) {
s := string(ufrag) s := string(ufrag)
init.UsernameFragment = &s init.UsernameFragment = &s
} }
err := c.GotICECandidate(init) err := f(init)
if err != nil { if err != nil {
log.Printf("WHIP candidate: %v", err) log.Printf("WHIP candidate: %v", err)
} }
} }
} }
w.WriteHeader(http.StatusNoContent) return nil
return
} }

35
webserver/whip_test.go Normal file
View file

@ -0,0 +1,35 @@
package webserver
import (
"strings"
"testing"
"github.com/pion/webrtc/v3"
)
func TestParseSDPFrag(t *testing.T) {
sdp := `a=ice-ufrag:FZ0m
a=ice-pwd:NRT+gj1EhsEwMm9MA7ljzBRy
m=audio 9 UDP/TLS/RTP/SAVPF 0
a=mid:0
a=candidate:2930517337 1 udp 2113937151 1eaafdf1-4127-499f-90d4-8c35ea49d5e6.local 44360 typ host generation 0 ufrag FZ0m network-cost 999
2024/09/30 00:07:41 {candidate:2930517337 1 udp 2113937151 1eaafdf1-4127-499f-90d4-8c35ea49d5e6.local 44360 typ host generation 0 ufrag FZ0m network-cost 999 0xc00062a580 0xc000620288 0xc00062a590}
a=end-of-candidates`
r := strings.NewReader(sdp)
candidates := []webrtc.ICECandidateInit(nil)
err := parseSDPFrag(r, func(c webrtc.ICECandidateInit) error {
candidates = append(candidates, c)
return nil
})
if err != nil {
t.Errorf("parseSDPFrag: %v", err)
}
if len(candidates) != 1 {
t.Errorf("Expected 1, got %v", candidates)
}
if *candidates[0].SDPMLineIndex != 0 ||
*candidates[0].SDPMid != "0" ||
*candidates[0].UsernameFragment != "FZ0m" {
t.Errorf("Got %v", candidates[0])
}
}