From 5c2e5ee5c0a0aeaccf43960a3e60dda71c0c8bfe Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Sat, 9 Dec 2023 20:46:45 +0100 Subject: [PATCH] Add test for parsing bearer tokens. --- webserver/webserver_test.go | 29 +++++++++++++++++++++++++++++ webserver/whip.go | 8 ++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/webserver/webserver_test.go b/webserver/webserver_test.go index 13a263e..faf4a5a 100644 --- a/webserver/webserver_test.go +++ b/webserver/webserver_test.go @@ -61,6 +61,35 @@ func TestParseWhip(t *testing.T) { } } +func TestParseBearerToken(t *testing.T) { + a := []struct{ a, b string }{ + {"", ""}, + {"foo", ""}, + {"foo bar", ""}, + {" foo bar", ""}, + {"foo bar ", ""}, + {"Bearer", ""}, + {"Bearer ", ""}, + {"Bearer foo", "foo"}, + {"bearer foo", "foo"}, + {" Bearer foo", "foo"}, + {"Bearer foo ", "foo"}, + {" Bearer foo ", "foo"}, + {"Bearer foo bar", ""}, + } + + for _, ab := range a { + t.Run(ab.a, func(t *testing.T) { + b := parseBearerToken(ab.a) + if b != ab.b { + t.Errorf("Bearer token %v, got %v, expected %v", + ab.a, b, ab.b, + ) + } + }) + } +} + func TestFormatICEServer(t *testing.T) { a := []struct { s webrtc.ICEServer diff --git a/webserver/whip.go b/webserver/whip.go index 7f833f1..74db7ff 100644 --- a/webserver/whip.go +++ b/webserver/whip.go @@ -52,8 +52,7 @@ func canPresent(perms []string) bool { return false } -func getBearerToken(r *http.Request) string { - auth := r.Header.Get("Authorization") +func parseBearerToken(auth string) string { auths := strings.Split(auth, ",") for _, a := range auths { a = strings.Trim(a, " \t") @@ -178,7 +177,8 @@ func whipEndpointHandler(w http.ResponseWriter, r *http.Request) { return } - token := getBearerToken(r) + token := parseBearerToken(r.Header.Get("Authorization")) + whip := "whip" creds := group.ClientCredentials{ Username: &whip, @@ -258,7 +258,7 @@ func whipResourceHandler(w http.ResponseWriter, r *http.Request) { } if t := c.Token(); t != "" { - token := getBearerToken(r) + token := parseBearerToken(r.Header.Get("Authorization")) if token != t { http.Error(w, "Forbidden", http.StatusForbidden) return