1
Fork 0

Protect download routes

This commit is contained in:
viktorstrate 2021-09-26 12:02:53 +02:00
parent 92662f9b8e
commit 80c2019b3f
No known key found for this signature in database
GPG Key ID: 3F855605109C1E8A
7 changed files with 335 additions and 232 deletions

View File

@ -55,6 +55,47 @@ func AddMediaShare(db *gorm.DB, userID int, mediaID int, expire *time.Time, pass
return &shareToken, nil
}
func AddAlbumShare(db *gorm.DB, user *models.User, albumID int, expire *time.Time, password *string) (*models.ShareToken, error) {
var count int64
err := db.
Model(&models.Album{}).
Where("EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = albums.id AND user_albums.user_id = ?)", user.ID).
Count(&count).Error
if err != nil {
return nil, errors.Wrap(err, "failed to validate album owner with database")
}
if count == 0 {
return nil, auth.ErrUnauthorized
}
var hashedPassword *string = nil
if password != nil {
hashedPassBytes, err := bcrypt.GenerateFromPassword([]byte(*password), 12)
if err != nil {
return nil, errors.Wrap(err, "failed to hash token password")
}
hashedStr := string(hashedPassBytes)
hashedPassword = &hashedStr
}
shareToken := models.ShareToken{
Value: utils.GenerateToken(),
OwnerID: user.ID,
Expire: expire,
Password: hashedPassword,
AlbumID: &albumID,
MediaID: nil,
}
if err := db.Create(&shareToken).Error; err != nil {
return nil, errors.Wrap(err, "failed to insert new share token into database")
}
return &shareToken, nil
}
func DeleteShareToken(db *gorm.DB, userID int, tokenValue string) (*models.ShareToken, error) {
token, err := getUserToken(db, userID, tokenValue)
if err != nil {

View File

@ -12,7 +12,6 @@ import (
"github.com/photoview/photoview/api/graphql/auth"
"github.com/photoview/photoview/api/graphql/models"
"github.com/photoview/photoview/api/graphql/models/actions"
"github.com/photoview/photoview/api/utils"
"golang.org/x/crypto/bcrypt"
)
@ -100,44 +99,7 @@ func (r *mutationResolver) ShareAlbum(ctx context.Context, albumID int, expire *
return nil, auth.ErrUnauthorized
}
var count int64
err := r.Database.
Model(&models.Album{}).
Where("EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = albums.id AND user_albums.user_id = ?)", user.ID).
Count(&count).Error
if err != nil {
return nil, errors.Wrap(err, "failed to validate album owner with database")
}
if count == 0 {
return nil, auth.ErrUnauthorized
}
var hashedPassword *string = nil
if password != nil {
hashedPassBytes, err := bcrypt.GenerateFromPassword([]byte(*password), 12)
if err != nil {
return nil, errors.Wrap(err, "failed to hash token password")
}
hashedStr := string(hashedPassBytes)
hashedPassword = &hashedStr
}
shareToken := models.ShareToken{
Value: utils.GenerateToken(),
OwnerID: user.ID,
Expire: expire,
Password: hashedPassword,
AlbumID: &albumID,
MediaID: nil,
}
if err := r.Database.Create(&shareToken).Error; err != nil {
return nil, errors.Wrap(err, "failed to insert new share token into database")
}
return &shareToken, nil
return actions.AddAlbumShare(r.Database, user, albumID, expire, password)
}
func (r *mutationResolver) ShareMedia(ctx context.Context, mediaID int, expire *time.Time, password *string) (*models.ShareToken, error) {

View File

@ -1,88 +0,0 @@
package routes
import (
"fmt"
"net/http"
"github.com/photoview/photoview/api/graphql/auth"
"github.com/photoview/photoview/api/graphql/models"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
func authenticateMedia(media *models.Media, db *gorm.DB, r *http.Request) (success bool, responseMessage string, responseStatus int, errorMessage error) {
user := auth.UserFromContext(r.Context())
if user != nil {
var album models.Album
if err := db.First(&album, media.AlbumID).Error; err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
ownsAlbum, err := user.OwnsAlbum(db, &album)
if err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
if !ownsAlbum {
return false, "invalid credentials", http.StatusForbidden, nil
}
} else {
// Check if photo is authorized with a share token
token := r.URL.Query().Get("token")
if token == "" {
return false, "unauthorized", http.StatusForbidden, nil
}
var shareToken models.ShareToken
if err := db.Where("value = ?", token).First(&shareToken).Error; err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
// Validate share token password, if set
if shareToken.Password != nil {
tokenPasswordCookie, err := r.Cookie(fmt.Sprintf("share-token-pw-%s", shareToken.Value))
if err != nil {
return false, "unauthorized", http.StatusForbidden, nil
}
// tokenPassword := r.Header.Get("TokenPassword")
tokenPassword := tokenPasswordCookie.Value
if err := bcrypt.CompareHashAndPassword([]byte(*shareToken.Password), []byte(tokenPassword)); err != nil {
if err == bcrypt.ErrMismatchedHashAndPassword {
return false, "unauthorized", http.StatusForbidden, nil
} else {
return false, "internal server error", http.StatusInternalServerError, err
}
}
}
if shareToken.AlbumID != nil && media.AlbumID != *shareToken.AlbumID {
// Check child albums
var count int
err := db.Raw(`
WITH recursive child_albums AS (
SELECT * FROM albums WHERE parent_album_id = ?
UNION ALL
SELECT child.* FROM albums child JOIN child_albums parent ON parent.id = child.parent_album_id
)
SELECT COUNT(id) FROM child_albums WHERE id = ?
`, *shareToken.AlbumID, media.AlbumID).Find(&count).Error
if err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
if count == 0 {
return false, "unauthorized", http.StatusForbidden, nil
}
}
if shareToken.MediaID != nil && media.ID != *shareToken.MediaID {
return false, "unauthorized", http.StatusForbidden, nil
}
}
return true, "success", http.StatusAccepted, nil
}

View File

@ -1,104 +0,0 @@
package routes
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/photoview/photoview/api/graphql/auth"
"github.com/photoview/photoview/api/graphql/models"
"github.com/photoview/photoview/api/graphql/models/actions"
"github.com/photoview/photoview/api/test_utils"
"github.com/stretchr/testify/assert"
)
func TestAuthenticateMedia(t *testing.T) {
db := test_utils.DatabaseTest(t)
user, err := models.RegisterUser(db, "username", nil, false)
if !assert.NoError(t, err) {
return
}
album := models.Album{
Title: "my_album",
}
if !assert.NoError(t, db.Model(&user).Association("Albums").Append(&album)) {
return
}
media := models.Media{
Title: "my_media",
Path: "/photos/image.jpg",
AlbumID: album.ID,
}
if !assert.NoError(t, db.Save(&media).Error) {
return
}
t.Run("Authorized request", func(t *testing.T) {
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
ctx := auth.AddUserToContext(req.Context(), user)
req = req.WithContext(ctx)
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.True(t, success)
assert.Equal(t, responseMessage, "success")
assert.Equal(t, responseStatus, http.StatusAccepted)
})
t.Run("Request without access token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.False(t, success)
assert.Equal(t, responseMessage, "unauthorized")
assert.Equal(t, responseStatus, http.StatusForbidden)
})
t.Run("Request without access token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.False(t, success)
assert.Equal(t, responseMessage, "unauthorized")
assert.Equal(t, responseStatus, http.StatusForbidden)
})
expire := time.Now().Add(time.Hour * 24 * 30)
tokenPassword := "token-password-123"
shareToken, err := actions.AddMediaShare(db, user.ID, media.ID, &expire, &tokenPassword)
if !assert.NoError(t, err) {
return
}
t.Run("Request with share token", func(t *testing.T) {
url := fmt.Sprintf("/photo/image.jpg?token=%s", shareToken.Value)
req := httptest.NewRequest("GET", url, strings.NewReader("IMAGE DATA"))
cookie := http.Cookie{
Name: fmt.Sprintf("share-token-pw-%s", shareToken.Value),
Value: tokenPassword,
}
req.AddCookie(&cookie)
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.True(t, success)
assert.Equal(t, responseMessage, "success")
assert.Equal(t, responseStatus, http.StatusAccepted)
})
}

View File

@ -0,0 +1,128 @@
package routes
import (
"fmt"
"net/http"
"github.com/photoview/photoview/api/graphql/auth"
"github.com/photoview/photoview/api/graphql/models"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
func authenticateMedia(media *models.Media, db *gorm.DB, r *http.Request) (success bool, responseMessage string, responseStatus int, errorMessage error) {
user := auth.UserFromContext(r.Context())
if user != nil {
var album models.Album
if err := db.First(&album, media.AlbumID).Error; err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
ownsAlbum, err := user.OwnsAlbum(db, &album)
if err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
if !ownsAlbum {
return false, "invalid credentials", http.StatusForbidden, nil
}
} else {
if success, respMsg, respStatus, err := shareTokenFromRequest(db, r, &media.ID, &media.AlbumID); !success {
return success, respMsg, respStatus, err
}
}
return true, "success", http.StatusAccepted, nil
}
func authenticateAlbum(album *models.Album, db *gorm.DB, r *http.Request) (success bool, responseMessage string, responseStatus int, errorMessage error) {
user := auth.UserFromContext(r.Context())
if user != nil {
ownsAlbum, err := user.OwnsAlbum(db, album)
if err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
if !ownsAlbum {
return false, "invalid credentials", http.StatusForbidden, nil
}
} else {
if success, respMsg, respStatus, err := shareTokenFromRequest(db, r, nil, &album.ID); !success {
return success, respMsg, respStatus, err
}
}
return true, "success", http.StatusAccepted, nil
}
func shareTokenFromRequest(db *gorm.DB, r *http.Request, mediaID *int, albumID *int) (success bool, responseMessage string, responseStatus int, errorMessage error) {
// Check if photo is authorized with a share token
token := r.URL.Query().Get("token")
if token == "" {
return false, "unauthorized", http.StatusForbidden, nil
}
var shareToken models.ShareToken
if err := db.Where("value = ?", token).First(&shareToken).Error; err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
// Validate share token password, if set
if shareToken.Password != nil {
tokenPasswordCookie, err := r.Cookie(fmt.Sprintf("share-token-pw-%s", shareToken.Value))
if err != nil {
return false, "unauthorized", http.StatusForbidden, nil
}
// tokenPassword := r.Header.Get("TokenPassword")
tokenPassword := tokenPasswordCookie.Value
if err := bcrypt.CompareHashAndPassword([]byte(*shareToken.Password), []byte(tokenPassword)); err != nil {
if err == bcrypt.ErrMismatchedHashAndPassword {
return false, "unauthorized", http.StatusForbidden, nil
} else {
return false, "internal server error", http.StatusInternalServerError, err
}
}
}
if shareToken.AlbumID != nil && albumID == nil {
return false, "unauthorized", http.StatusForbidden, nil
}
if shareToken.MediaID != nil && mediaID == nil {
return false, "unauthorized", http.StatusForbidden, nil
}
if shareToken.AlbumID != nil && *albumID != *shareToken.AlbumID {
// Check child albums
var count int
err := db.Raw(`
WITH recursive child_albums AS (
SELECT * FROM albums WHERE parent_album_id = ?
UNION ALL
SELECT child.* FROM albums child JOIN child_albums parent ON parent.id = child.parent_album_id
)
SELECT COUNT(id) FROM child_albums WHERE id = ?
`, *shareToken.AlbumID, albumID).Find(&count).Error
if err != nil {
return false, "internal server error", http.StatusInternalServerError, err
}
if count == 0 {
return false, "unauthorized", http.StatusForbidden, nil
}
}
if shareToken.MediaID != nil && *mediaID != *shareToken.MediaID {
return false, "unauthorized", http.StatusForbidden, nil
}
return true, "", 0, nil
}

View File

@ -0,0 +1,147 @@
package routes
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/photoview/photoview/api/graphql/auth"
"github.com/photoview/photoview/api/graphql/models"
"github.com/photoview/photoview/api/graphql/models/actions"
"github.com/photoview/photoview/api/test_utils"
"github.com/stretchr/testify/assert"
)
func TestAuthenticateRoute(t *testing.T) {
db := test_utils.DatabaseTest(t)
user, err := models.RegisterUser(db, "username", nil, false)
if !assert.NoError(t, err) {
return
}
album := models.Album{
Title: "my_album",
Path: "/photos",
}
if !assert.NoError(t, db.Model(&user).Association("Albums").Append(&album)) {
return
}
media := models.Media{
Title: "my_media",
Path: "/photos/image.jpg",
AlbumID: album.ID,
}
if !assert.NoError(t, db.Save(&media).Error) {
return
}
t.Run("Authenticate Media", func(t *testing.T) {
t.Run("Authorized request", func(t *testing.T) {
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
ctx := auth.AddUserToContext(req.Context(), user)
req = req.WithContext(ctx)
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.True(t, success)
assert.Equal(t, responseMessage, "success")
assert.Equal(t, responseStatus, http.StatusAccepted)
})
t.Run("Request without access token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.False(t, success)
assert.Equal(t, responseMessage, "unauthorized")
assert.Equal(t, responseStatus, http.StatusForbidden)
})
expire := time.Now().Add(time.Hour * 24 * 30)
tokenPassword := "token-password-123"
shareToken, err := actions.AddMediaShare(db, user.ID, media.ID, &expire, &tokenPassword)
if !assert.NoError(t, err) {
return
}
t.Run("Request with share token", func(t *testing.T) {
url := fmt.Sprintf("/photo/image.jpg?token=%s", shareToken.Value)
req := httptest.NewRequest("GET", url, strings.NewReader("IMAGE DATA"))
cookie := http.Cookie{
Name: fmt.Sprintf("share-token-pw-%s", shareToken.Value),
Value: tokenPassword,
}
req.AddCookie(&cookie)
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.True(t, success)
assert.Equal(t, responseMessage, "success")
assert.Equal(t, responseStatus, http.StatusAccepted)
})
})
t.Run("Authenticate Album", func(t *testing.T) {
t.Run("Authorized request", func(t *testing.T) {
req := httptest.NewRequest("GET", "/download/album/1", strings.NewReader("ALBUM DATA"))
ctx := auth.AddUserToContext(req.Context(), user)
req = req.WithContext(ctx)
success, responseMessage, responseStatus, err := authenticateAlbum(&album, db, req)
assert.NoError(t, err)
assert.True(t, success)
assert.Equal(t, responseMessage, "success")
assert.Equal(t, responseStatus, http.StatusAccepted)
})
t.Run("Request without access token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/download/album/1", strings.NewReader("ALBUM DATA"))
success, responseMessage, responseStatus, err := authenticateAlbum(&album, db, req)
assert.NoError(t, err)
assert.False(t, success)
assert.Equal(t, responseMessage, "unauthorized")
assert.Equal(t, responseStatus, http.StatusForbidden)
})
expire := time.Now().Add(time.Hour * 24 * 30)
tokenPassword := "token-password-123"
shareToken, err := actions.AddAlbumShare(db, user, album.ID, &expire, &tokenPassword)
if !assert.NoError(t, err) {
return
}
t.Run("Request with share token", func(t *testing.T) {
url := fmt.Sprintf("/download/album/1?token=%s", shareToken.Value)
req := httptest.NewRequest("GET", url, strings.NewReader("ALBUM DATA"))
cookie := http.Cookie{
Name: fmt.Sprintf("share-token-pw-%s", shareToken.Value),
Value: tokenPassword,
}
req.AddCookie(&cookie)
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
assert.NoError(t, err)
assert.True(t, success)
assert.Equal(t, responseMessage, "success")
assert.Equal(t, responseStatus, http.StatusAccepted)
})
})
}

View File

@ -7,6 +7,7 @@ import (
"log"
"net/http"
"os"
"strings"
"github.com/gorilla/mux"
"github.com/photoview/photoview/api/graphql/models"
@ -17,6 +18,7 @@ func RegisterDownloadRoutes(db *gorm.DB, router *mux.Router) {
router.HandleFunc("/album/{album_id}/{media_purpose}", func(w http.ResponseWriter, r *http.Request) {
albumID := mux.Vars(r)["album_id"]
mediaPurpose := mux.Vars(r)["media_purpose"]
mediaPurposeList := strings.SplitN(mediaPurpose, ",", 10)
var album models.Album
if err := db.Find(&album, albumID).Error; err != nil {
@ -25,13 +27,28 @@ func RegisterDownloadRoutes(db *gorm.DB, router *mux.Router) {
return
}
if success, response, status, err := authenticateAlbum(&album, db, r); !success {
if err != nil {
log.Printf("WARN: error authenticating album for download: %v\n", err)
}
w.WriteHeader(status)
w.Write([]byte(response))
return
}
var mediaURLs []*models.MediaURL
if err := db.Joins("Media").Where("media.album_id = ?", album.ID).Where("media_urls.purpose = ?", mediaPurpose).Find(&mediaURLs).Error; err != nil {
if err := db.Joins("Media").Where("media.album_id = ?", album.ID).Where("media_urls.purpose IN (?)", mediaPurposeList).Find(&mediaURLs).Error; err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("internal server error"))
return
}
if len(mediaURLs) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("no media found"))
return
}
w.Header().Set("Content-Type", "application/zip")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.zip\"", album.Title))