44 lines
956 B
Go
44 lines
956 B
Go
|
package server
|
||
|
|
||
|
import (
	"log"
	"net/http"
	"net/url"
	"os"
	"strings"

	"github.com/gorilla/websocket"
)
|
||
|
|
||
|
func WebsocketUpgrader(devMode bool) websocket.Upgrader {
|
||
|
return websocket.Upgrader{
|
||
|
CheckOrigin: func(r *http.Request) bool {
|
||
|
if devMode {
|
||
|
return true
|
||
|
} else {
|
||
|
pubEndpoint, err := url.Parse(os.Getenv("PUBLIC_ENDPOINT"))
|
||
|
if err != nil {
|
||
|
log.Printf("Could not parse API_ENDPOINT environment variable as url: %s", err)
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if r.Header.Get("origin") == "" {
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
originURL, err := url.Parse(r.Header.Get("origin"))
|
||
|
if err != nil {
|
||
|
log.Printf("Could not parse origin header of websocket request: %s", err)
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if pubEndpoint.Host == originURL.Host {
|
||
|
return true
|
||
|
} else {
|
||
|
log.Printf("Not allowing websocket request from %s because it doesn't match PUBLIC_ENDPOINT %s", originURL.Host, pubEndpoint.Host)
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
}
|