From cb7a087ea24e286a51efa99b4fa1be35101596ad Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Thu, 15 Aug 2024 00:41:27 +0200 Subject: [PATCH] Use mime.ParseMediaType instead of our version. --- webserver/api.go | 23 +++++++++++------------ webserver/api_test.go | 6 ++++-- webserver/webserver_test.go | 19 ------------------- 3 files changed, 15 insertions(+), 33 deletions(-) diff --git a/webserver/api.go b/webserver/api.go index 61461e4..983bc92 100644 --- a/webserver/api.go +++ b/webserver/api.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "io" + "mime" "net/http" "os" "strings" @@ -19,10 +20,6 @@ import ( "github.com/jech/galene/token" ) -func parseContentType(ctype string) string { - return strings.Trim(strings.Split(ctype, ";")[0], " ") -} - // checkAdmin checks whether the client authentifies as an administrator func checkAdmin(w http.ResponseWriter, r *http.Request) bool { username, password, ok := r.BasicAuth() @@ -73,8 +70,8 @@ func sendJSON(w http.ResponseWriter, r *http.Request, v any) { } func getText(w http.ResponseWriter, r *http.Request) ([]byte, bool) { - ctype := parseContentType(r.Header.Get("Content-Type")) - if !strings.EqualFold(ctype, "text/plain") { + ctype, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil || !strings.EqualFold(ctype, "text/plain") { w.Header().Set("Accept", "text/plain") http.Error(w, "unsupported content type", http.StatusUnsupportedMediaType) @@ -91,8 +88,8 @@ func getText(w http.ResponseWriter, r *http.Request) ([]byte, bool) { } func getJSON(w http.ResponseWriter, r *http.Request, v any) bool { - ctype := parseContentType(r.Header.Get("Content-Type")) - if !strings.EqualFold(ctype, "application/json") { + ctype, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil || !strings.EqualFold(ctype, "application/json") { w.Header().Set("Accept", "application/json") http.Error(w, "unsupported content type", http.StatusUnsupportedMediaType) @@ -100,7 +97,7 @@ func getJSON(w http.ResponseWriter, r *http.Request, v any) bool { } d := json.NewDecoder(r.Body) - err := d.Decode(v) + err = d.Decode(v) if err != nil { httpError(w, err) return true @@ -458,8 +455,10 @@ func keysHandler(w http.ResponseWriter, r *http.Request, g string) { if r.Method == "PUT" { // cannot use getJSON due to the weird content-type - ctype := parseContentType(r.Header.Get("Content-Type")) - if !strings.EqualFold(ctype, "application/jwk-set+json") { + ctype, _, err := + mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil || + !strings.EqualFold(ctype, "application/jwk-set+json") { w.Header().Set("Accept", "application/jwk-set+json") http.Error(w, "unsupported content type", http.StatusUnsupportedMediaType) @@ -467,7 +466,7 @@ func keysHandler(w http.ResponseWriter, r *http.Request, g string) { } d := json.NewDecoder(r.Body) var keys jwkset - err := d.Decode(&keys) + err = d.Decode(&keys) if err != nil { httpError(w, err) return diff --git a/webserver/api_test.go b/webserver/api_test.go index e26c3f1..8bc6616 100644 --- a/webserver/api_test.go +++ b/webserver/api_test.go @@ -3,6 +3,7 @@ package webserver import ( "errors" "fmt" + "mime" "os" "reflect" "strings" @@ -104,8 +105,9 @@ func TestApi(t *testing.T) { if resp.StatusCode != http.StatusOK { return fmt.Errorf("Status is %v", resp.StatusCode) } - ctype := parseContentType(resp.Header.Get("Content-Type")) - if !strings.EqualFold(ctype, "application/json") { + ctype, _, err := + mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil || !strings.EqualFold(ctype, "application/json") { return errors.New("Unexpected content-type") } d := json.NewDecoder(resp.Body) diff --git a/webserver/webserver_test.go b/webserver/webserver_test.go index dbcab15..007816c 100644 --- a/webserver/webserver_test.go +++ b/webserver/webserver_test.go @@ -124,25 +124,6 @@ func TestParseSplit(t *testing.T) { } } -func TestParseContentType(t *testing.T) { - a := []struct{ a, b string }{ - {"", ""}, - {"text/plain", "text/plain"}, - {"text/plain;charset=utf-8", "text/plain"}, - {"text/plain; charset=utf-8", "text/plain"}, - {"text/plain ; charset=utf-8", "text/plain"}, - } - - for _, ab := range a { - b := parseContentType(ab.a) - if b != ab.b { - t.Errorf("Content type %v, got %v, expected %v", - ab.a, b, ab.b, - ) - } - } -} - func TestParseBearerToken(t *testing.T) { a := []struct{ a, b string }{ {"", ""},