diff --git a/webserver/precondition.go b/webserver/precondition.go new file mode 100644 index 0000000..d13da32 --- /dev/null +++ b/webserver/precondition.go @@ -0,0 +1,110 @@ +package webserver + +import ( + "net/http" + "strings" +) + +// This is partly based on the Go standard library file net/http/fs.go. + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = strings.TrimLeft(s, " \t\n\r") + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagMatch returns whether the given eTag matches an IM/INM header value. +// The empty string in etag is interpreted as a non-existent object. +func etagMatch(etag, header string) bool { + if header == "" { + return false + } + if header == etag { + return true + } + + for { + header = strings.TrimLeft(header, " \t\n\r") + if len(header) == 0 { + break + } + if header[0] == ',' { + header = header[1:] + continue + } + if header[0] == '*' { + return etag != "" + } + e, remain := scanETag(header) + if e == "" { + break + } + if e == etag { + return true + } + header = remain + } + + return false +} + +func writeNotModified(w http.ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + delete(h, "Content-Encoding") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(http.StatusNotModified) +} + +// checkPreconditions evaluates request preconditions. +// It interprets an empty etag as a non-existent object. +func checkPreconditions(w http.ResponseWriter, r *http.Request, etag string) (done bool) { + // RFC 7232 section 6. + im := r.Header.Get("If-Match") + if im != "" && !etagMatch(etag, im) { + w.WriteHeader(http.StatusPreconditionFailed) + return true + } + inm := r.Header.Get("If-None-Match") + if inm != "" && etagMatch(etag, inm) { + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true + } else { + w.WriteHeader(http.StatusPreconditionFailed) + return true + } + } + + return false +} diff --git a/webserver/precondition_test.go b/webserver/precondition_test.go new file mode 100644 index 0000000..5ed61b3 --- /dev/null +++ b/webserver/precondition_test.go @@ -0,0 +1,114 @@ +package webserver + +import ( + "net/http" + "testing" +) + +type testWriter struct { + statusCode int +} + +func (w *testWriter) Header() http.Header { + return nil +} + +func (w *testWriter) Write(buf []byte) (int, error) { + return len(buf), nil +} + +func (w *testWriter) WriteHeader(statusCode int) { + if w.statusCode != 0 { + panic("WriteHeader called twice") + } + w.statusCode = statusCode +} + +func TestEtagMatch(t *testing.T) { + type tst struct { + etag, header string + } + + var match = []tst{ + {`"foo"`, `"foo"`}, + {`"foo"`, ` "foo"`}, + {`"foo"`, `"foo" `}, + {`"foo"`, ` "foo" `}, + {`"foo"`, `"foo", "bar"`}, + {`"foo"`, `"bar", "foo"`}, + {`W/"foo"`, `W/"foo"`}, + } + + var mismatch = []tst{ + {``, ``}, + {``, `*`}, + {``, `"foo"`}, + {`"foo"`, ``}, + {`"foo"`, `"bar"`}, + {`"foo"`, `"bar", "baz"`}, + {`"foo"`, `"baz", "bar"`}, + {`"foo"`, `W/"foo"`}, + {`W/"foo"`, `"foo"`}, + } + + for _, tst := range match { + m := etagMatch(tst.etag, tst.header) + if !m { + t.Errorf("%#v %#v: got %v, expected true", + tst.etag, tst.header, m, + ) + } + } + + for _, tst := range mismatch { + m := etagMatch(tst.etag, tst.header) + if m { + t.Errorf("%#v %#v: got %v, expected false", + tst.etag, tst.header, m, + ) + } + } +} + +func TestCheckPreconditions(t *testing.T) { + var tests = []struct { + method, etag, im, inm string + result int + }{ + {"GET", ``, ``, ``, 0}, + {"GET", ``, `*`, ``, 412}, + {"GET", ``, ``, `*`, 0}, + {"POST", ``, `*`, ``, 412}, + {"POST", ``, ``, `*`, 0}, + {"GET", `"123"`, ``, ``, 0}, + {"GET", `"123"`, `"123"`, ``, 0}, + {"GET", `"123"`, `"124"`, ``, 412}, + {"POST", `"123"`, `"124"`, ``, 412}, + {"GET", `"123"`, `*`, ``, 0}, + {"GET", `"123"`, ``, `"123"`, 304}, + {"POST", `"123"`, ``, `"123"`, 412}, + {"GET", `"123"`, ``, `"124"`, 0}, + {"GET", `"123"`, ``, `*`, 304}, + } + + for _, tst := range tests { + var w testWriter + h := make(http.Header) + if tst.im != "" { + h.Set("If-Match", tst.im) + } + if tst.inm != "" { + h.Set("If-None-Match", tst.inm) + } + r := http.Request{ + Method: tst.method, + Header: h, + } + done := checkPreconditions(&w, &r, tst.etag) + if done != (tst.result != 0) || w.statusCode != tst.result { + t.Errorf("%#v %#v %#v: got %v, expected %v", + tst.etag, tst.im, tst.inm, + w.statusCode, tst.result) + } + } +}