1
Fork 0
photoview/api/server/websocket.go

44 lines
956 B
Go
Raw Normal View History

2020-02-21 17:53:04 +01:00
package server
import (
	"log"
	"net/http"
	"net/url"
	"os"
	"strings"

	"github.com/gorilla/websocket"
)
// WebsocketUpgrader returns a websocket.Upgrader whose CheckOrigin policy
// depends on devMode.
//
// When devMode is true every origin is accepted. Otherwise the handshake is
// accepted only if publicEndpointOriginAllowed approves it (see that
// function for the exact rules).
func WebsocketUpgrader(devMode bool) websocket.Upgrader {
	return websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			if devMode {
				return true
			}
			return publicEndpointOriginAllowed(r)
		},
	}
}

// publicEndpointOriginAllowed reports whether the websocket handshake r may
// proceed, judged against the PUBLIC_ENDPOINT environment variable. The
// variable is re-read on every call.
//
// Rules, in order:
//   - PUBLIC_ENDPOINT that fails to parse as a URL: reject.
//   - Request without an Origin header: accept.
//   - Origin header that fails to parse as a URL: reject.
//   - PUBLIC_ENDPOINT with no host component: reject.
//   - Otherwise accept iff the Origin host equals the PUBLIC_ENDPOINT host,
//     compared case-insensitively.
func publicEndpointOriginAllowed(r *http.Request) bool {
	pubEndpoint, err := url.Parse(os.Getenv("PUBLIC_ENDPOINT"))
	if err != nil {
		log.Printf("Could not parse PUBLIC_ENDPOINT environment variable as url: %s", err)
		return false
	}

	origin := r.Header.Get("origin")
	if origin == "" {
		return true
	}

	originURL, err := url.Parse(origin)
	if err != nil {
		log.Printf("Could not parse origin header of websocket request: %s", err)
		return false
	}

	// Without this guard an unset PUBLIC_ENDPOINT, or one given without a
	// scheme (which url.Parse treats as a path), yields an empty Host. That
	// would then "match" any Origin that also parses to an empty host, such
	// as the opaque "null" origin.
	if pubEndpoint.Host == "" {
		log.Printf("Not allowing websocket request from %q because PUBLIC_ENDPOINT has no host", origin)
		return false
	}

	// Host names are case-insensitive.
	if strings.EqualFold(pubEndpoint.Host, originURL.Host) {
		return true
	}

	log.Printf("Not allowing websocket request from %s because it doesn't match PUBLIC_ENDPOINT %s", originURL.Host, pubEndpoint.Host)
	return false
}