Protect download routes
This commit is contained in:
parent
92662f9b8e
commit
80c2019b3f
|
@ -55,6 +55,47 @@ func AddMediaShare(db *gorm.DB, userID int, mediaID int, expire *time.Time, pass
|
||||||
return &shareToken, nil
|
return &shareToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AddAlbumShare(db *gorm.DB, user *models.User, albumID int, expire *time.Time, password *string) (*models.ShareToken, error) {
|
||||||
|
var count int64
|
||||||
|
err := db.
|
||||||
|
Model(&models.Album{}).
|
||||||
|
Where("EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = albums.id AND user_albums.user_id = ?)", user.ID).
|
||||||
|
Count(&count).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to validate album owner with database")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
return nil, auth.ErrUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
var hashedPassword *string = nil
|
||||||
|
if password != nil {
|
||||||
|
hashedPassBytes, err := bcrypt.GenerateFromPassword([]byte(*password), 12)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to hash token password")
|
||||||
|
}
|
||||||
|
hashedStr := string(hashedPassBytes)
|
||||||
|
hashedPassword = &hashedStr
|
||||||
|
}
|
||||||
|
|
||||||
|
shareToken := models.ShareToken{
|
||||||
|
Value: utils.GenerateToken(),
|
||||||
|
OwnerID: user.ID,
|
||||||
|
Expire: expire,
|
||||||
|
Password: hashedPassword,
|
||||||
|
AlbumID: &albumID,
|
||||||
|
MediaID: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(&shareToken).Error; err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to insert new share token into database")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &shareToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
func DeleteShareToken(db *gorm.DB, userID int, tokenValue string) (*models.ShareToken, error) {
|
func DeleteShareToken(db *gorm.DB, userID int, tokenValue string) (*models.ShareToken, error) {
|
||||||
token, err := getUserToken(db, userID, tokenValue)
|
token, err := getUserToken(db, userID, tokenValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/photoview/photoview/api/graphql/auth"
|
"github.com/photoview/photoview/api/graphql/auth"
|
||||||
"github.com/photoview/photoview/api/graphql/models"
|
"github.com/photoview/photoview/api/graphql/models"
|
||||||
"github.com/photoview/photoview/api/graphql/models/actions"
|
"github.com/photoview/photoview/api/graphql/models/actions"
|
||||||
"github.com/photoview/photoview/api/utils"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -100,44 +99,7 @@ func (r *mutationResolver) ShareAlbum(ctx context.Context, albumID int, expire *
|
||||||
return nil, auth.ErrUnauthorized
|
return nil, auth.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
var count int64
|
return actions.AddAlbumShare(r.Database, user, albumID, expire, password)
|
||||||
err := r.Database.
|
|
||||||
Model(&models.Album{}).
|
|
||||||
Where("EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = albums.id AND user_albums.user_id = ?)", user.ID).
|
|
||||||
Count(&count).Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to validate album owner with database")
|
|
||||||
}
|
|
||||||
|
|
||||||
if count == 0 {
|
|
||||||
return nil, auth.ErrUnauthorized
|
|
||||||
}
|
|
||||||
|
|
||||||
var hashedPassword *string = nil
|
|
||||||
if password != nil {
|
|
||||||
hashedPassBytes, err := bcrypt.GenerateFromPassword([]byte(*password), 12)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to hash token password")
|
|
||||||
}
|
|
||||||
hashedStr := string(hashedPassBytes)
|
|
||||||
hashedPassword = &hashedStr
|
|
||||||
}
|
|
||||||
|
|
||||||
shareToken := models.ShareToken{
|
|
||||||
Value: utils.GenerateToken(),
|
|
||||||
OwnerID: user.ID,
|
|
||||||
Expire: expire,
|
|
||||||
Password: hashedPassword,
|
|
||||||
AlbumID: &albumID,
|
|
||||||
MediaID: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.Database.Create(&shareToken).Error; err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to insert new share token into database")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &shareToken, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mutationResolver) ShareMedia(ctx context.Context, mediaID int, expire *time.Time, password *string) (*models.ShareToken, error) {
|
func (r *mutationResolver) ShareMedia(ctx context.Context, mediaID int, expire *time.Time, password *string) (*models.ShareToken, error) {
|
||||||
|
|
|
@ -1,88 +0,0 @@
|
||||||
package routes
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/photoview/photoview/api/graphql/auth"
|
|
||||||
"github.com/photoview/photoview/api/graphql/models"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func authenticateMedia(media *models.Media, db *gorm.DB, r *http.Request) (success bool, responseMessage string, responseStatus int, errorMessage error) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
|
|
||||||
if user != nil {
|
|
||||||
var album models.Album
|
|
||||||
if err := db.First(&album, media.AlbumID).Error; err != nil {
|
|
||||||
return false, "internal server error", http.StatusInternalServerError, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ownsAlbum, err := user.OwnsAlbum(db, &album)
|
|
||||||
if err != nil {
|
|
||||||
return false, "internal server error", http.StatusInternalServerError, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ownsAlbum {
|
|
||||||
return false, "invalid credentials", http.StatusForbidden, nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Check if photo is authorized with a share token
|
|
||||||
token := r.URL.Query().Get("token")
|
|
||||||
if token == "" {
|
|
||||||
return false, "unauthorized", http.StatusForbidden, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var shareToken models.ShareToken
|
|
||||||
if err := db.Where("value = ?", token).First(&shareToken).Error; err != nil {
|
|
||||||
return false, "internal server error", http.StatusInternalServerError, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate share token password, if set
|
|
||||||
if shareToken.Password != nil {
|
|
||||||
tokenPasswordCookie, err := r.Cookie(fmt.Sprintf("share-token-pw-%s", shareToken.Value))
|
|
||||||
if err != nil {
|
|
||||||
return false, "unauthorized", http.StatusForbidden, nil
|
|
||||||
}
|
|
||||||
// tokenPassword := r.Header.Get("TokenPassword")
|
|
||||||
tokenPassword := tokenPasswordCookie.Value
|
|
||||||
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(*shareToken.Password), []byte(tokenPassword)); err != nil {
|
|
||||||
if err == bcrypt.ErrMismatchedHashAndPassword {
|
|
||||||
return false, "unauthorized", http.StatusForbidden, nil
|
|
||||||
} else {
|
|
||||||
return false, "internal server error", http.StatusInternalServerError, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shareToken.AlbumID != nil && media.AlbumID != *shareToken.AlbumID {
|
|
||||||
// Check child albums
|
|
||||||
|
|
||||||
var count int
|
|
||||||
err := db.Raw(`
|
|
||||||
WITH recursive child_albums AS (
|
|
||||||
SELECT * FROM albums WHERE parent_album_id = ?
|
|
||||||
UNION ALL
|
|
||||||
SELECT child.* FROM albums child JOIN child_albums parent ON parent.id = child.parent_album_id
|
|
||||||
)
|
|
||||||
SELECT COUNT(id) FROM child_albums WHERE id = ?
|
|
||||||
`, *shareToken.AlbumID, media.AlbumID).Find(&count).Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, "internal server error", http.StatusInternalServerError, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if count == 0 {
|
|
||||||
return false, "unauthorized", http.StatusForbidden, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shareToken.MediaID != nil && media.ID != *shareToken.MediaID {
|
|
||||||
return false, "unauthorized", http.StatusForbidden, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, "success", http.StatusAccepted, nil
|
|
||||||
}
|
|
|
@ -1,104 +0,0 @@
|
||||||
package routes
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/photoview/photoview/api/graphql/auth"
|
|
||||||
"github.com/photoview/photoview/api/graphql/models"
|
|
||||||
"github.com/photoview/photoview/api/graphql/models/actions"
|
|
||||||
"github.com/photoview/photoview/api/test_utils"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthenticateMedia(t *testing.T) {
|
|
||||||
db := test_utils.DatabaseTest(t)
|
|
||||||
|
|
||||||
user, err := models.RegisterUser(db, "username", nil, false)
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
album := models.Album{
|
|
||||||
Title: "my_album",
|
|
||||||
}
|
|
||||||
|
|
||||||
if !assert.NoError(t, db.Model(&user).Association("Albums").Append(&album)) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
media := models.Media{
|
|
||||||
Title: "my_media",
|
|
||||||
Path: "/photos/image.jpg",
|
|
||||||
AlbumID: album.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !assert.NoError(t, db.Save(&media).Error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("Authorized request", func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
|
||||||
ctx := auth.AddUserToContext(req.Context(), user)
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, success)
|
|
||||||
assert.Equal(t, responseMessage, "success")
|
|
||||||
assert.Equal(t, responseStatus, http.StatusAccepted)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Request without access token", func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
|
||||||
|
|
||||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.False(t, success)
|
|
||||||
assert.Equal(t, responseMessage, "unauthorized")
|
|
||||||
assert.Equal(t, responseStatus, http.StatusForbidden)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Request without access token", func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
|
||||||
|
|
||||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.False(t, success)
|
|
||||||
assert.Equal(t, responseMessage, "unauthorized")
|
|
||||||
assert.Equal(t, responseStatus, http.StatusForbidden)
|
|
||||||
})
|
|
||||||
|
|
||||||
expire := time.Now().Add(time.Hour * 24 * 30)
|
|
||||||
tokenPassword := "token-password-123"
|
|
||||||
shareToken, err := actions.AddMediaShare(db, user.ID, media.ID, &expire, &tokenPassword)
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("Request with share token", func(t *testing.T) {
|
|
||||||
url := fmt.Sprintf("/photo/image.jpg?token=%s", shareToken.Value)
|
|
||||||
req := httptest.NewRequest("GET", url, strings.NewReader("IMAGE DATA"))
|
|
||||||
|
|
||||||
cookie := http.Cookie{
|
|
||||||
Name: fmt.Sprintf("share-token-pw-%s", shareToken.Value),
|
|
||||||
Value: tokenPassword,
|
|
||||||
}
|
|
||||||
req.AddCookie(&cookie)
|
|
||||||
|
|
||||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, success)
|
|
||||||
assert.Equal(t, responseMessage, "success")
|
|
||||||
assert.Equal(t, responseStatus, http.StatusAccepted)
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/photoview/photoview/api/graphql/auth"
|
||||||
|
"github.com/photoview/photoview/api/graphql/models"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func authenticateMedia(media *models.Media, db *gorm.DB, r *http.Request) (success bool, responseMessage string, responseStatus int, errorMessage error) {
|
||||||
|
user := auth.UserFromContext(r.Context())
|
||||||
|
|
||||||
|
if user != nil {
|
||||||
|
var album models.Album
|
||||||
|
if err := db.First(&album, media.AlbumID).Error; err != nil {
|
||||||
|
return false, "internal server error", http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ownsAlbum, err := user.OwnsAlbum(db, &album)
|
||||||
|
if err != nil {
|
||||||
|
return false, "internal server error", http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ownsAlbum {
|
||||||
|
return false, "invalid credentials", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if success, respMsg, respStatus, err := shareTokenFromRequest(db, r, &media.ID, &media.AlbumID); !success {
|
||||||
|
return success, respMsg, respStatus, err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, "success", http.StatusAccepted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func authenticateAlbum(album *models.Album, db *gorm.DB, r *http.Request) (success bool, responseMessage string, responseStatus int, errorMessage error) {
|
||||||
|
user := auth.UserFromContext(r.Context())
|
||||||
|
|
||||||
|
if user != nil {
|
||||||
|
ownsAlbum, err := user.OwnsAlbum(db, album)
|
||||||
|
if err != nil {
|
||||||
|
return false, "internal server error", http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ownsAlbum {
|
||||||
|
return false, "invalid credentials", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if success, respMsg, respStatus, err := shareTokenFromRequest(db, r, nil, &album.ID); !success {
|
||||||
|
return success, respMsg, respStatus, err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, "success", http.StatusAccepted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shareTokenFromRequest(db *gorm.DB, r *http.Request, mediaID *int, albumID *int) (success bool, responseMessage string, responseStatus int, errorMessage error) {
|
||||||
|
// Check if photo is authorized with a share token
|
||||||
|
token := r.URL.Query().Get("token")
|
||||||
|
if token == "" {
|
||||||
|
return false, "unauthorized", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var shareToken models.ShareToken
|
||||||
|
|
||||||
|
if err := db.Where("value = ?", token).First(&shareToken).Error; err != nil {
|
||||||
|
return false, "internal server error", http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate share token password, if set
|
||||||
|
if shareToken.Password != nil {
|
||||||
|
tokenPasswordCookie, err := r.Cookie(fmt.Sprintf("share-token-pw-%s", shareToken.Value))
|
||||||
|
if err != nil {
|
||||||
|
return false, "unauthorized", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
// tokenPassword := r.Header.Get("TokenPassword")
|
||||||
|
tokenPassword := tokenPasswordCookie.Value
|
||||||
|
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(*shareToken.Password), []byte(tokenPassword)); err != nil {
|
||||||
|
if err == bcrypt.ErrMismatchedHashAndPassword {
|
||||||
|
return false, "unauthorized", http.StatusForbidden, nil
|
||||||
|
} else {
|
||||||
|
return false, "internal server error", http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shareToken.AlbumID != nil && albumID == nil {
|
||||||
|
return false, "unauthorized", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if shareToken.MediaID != nil && mediaID == nil {
|
||||||
|
return false, "unauthorized", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if shareToken.AlbumID != nil && *albumID != *shareToken.AlbumID {
|
||||||
|
// Check child albums
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err := db.Raw(`
|
||||||
|
WITH recursive child_albums AS (
|
||||||
|
SELECT * FROM albums WHERE parent_album_id = ?
|
||||||
|
UNION ALL
|
||||||
|
SELECT child.* FROM albums child JOIN child_albums parent ON parent.id = child.parent_album_id
|
||||||
|
)
|
||||||
|
SELECT COUNT(id) FROM child_albums WHERE id = ?
|
||||||
|
`, *shareToken.AlbumID, albumID).Find(&count).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, "internal server error", http.StatusInternalServerError, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
return false, "unauthorized", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shareToken.MediaID != nil && *mediaID != *shareToken.MediaID {
|
||||||
|
return false, "unauthorized", http.StatusForbidden, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, "", 0, nil
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/photoview/photoview/api/graphql/auth"
|
||||||
|
"github.com/photoview/photoview/api/graphql/models"
|
||||||
|
"github.com/photoview/photoview/api/graphql/models/actions"
|
||||||
|
"github.com/photoview/photoview/api/test_utils"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthenticateRoute(t *testing.T) {
|
||||||
|
db := test_utils.DatabaseTest(t)
|
||||||
|
|
||||||
|
user, err := models.RegisterUser(db, "username", nil, false)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
album := models.Album{
|
||||||
|
Title: "my_album",
|
||||||
|
Path: "/photos",
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.NoError(t, db.Model(&user).Association("Albums").Append(&album)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
media := models.Media{
|
||||||
|
Title: "my_media",
|
||||||
|
Path: "/photos/image.jpg",
|
||||||
|
AlbumID: album.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.NoError(t, db.Save(&media).Error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Authenticate Media", func(t *testing.T) {
|
||||||
|
t.Run("Authorized request", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
||||||
|
ctx := auth.AddUserToContext(req.Context(), user)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, success)
|
||||||
|
assert.Equal(t, responseMessage, "success")
|
||||||
|
assert.Equal(t, responseStatus, http.StatusAccepted)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Request without access token", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
||||||
|
|
||||||
|
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, success)
|
||||||
|
assert.Equal(t, responseMessage, "unauthorized")
|
||||||
|
assert.Equal(t, responseStatus, http.StatusForbidden)
|
||||||
|
})
|
||||||
|
|
||||||
|
expire := time.Now().Add(time.Hour * 24 * 30)
|
||||||
|
tokenPassword := "token-password-123"
|
||||||
|
shareToken, err := actions.AddMediaShare(db, user.ID, media.ID, &expire, &tokenPassword)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Request with share token", func(t *testing.T) {
|
||||||
|
url := fmt.Sprintf("/photo/image.jpg?token=%s", shareToken.Value)
|
||||||
|
req := httptest.NewRequest("GET", url, strings.NewReader("IMAGE DATA"))
|
||||||
|
|
||||||
|
cookie := http.Cookie{
|
||||||
|
Name: fmt.Sprintf("share-token-pw-%s", shareToken.Value),
|
||||||
|
Value: tokenPassword,
|
||||||
|
}
|
||||||
|
req.AddCookie(&cookie)
|
||||||
|
|
||||||
|
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, success)
|
||||||
|
assert.Equal(t, responseMessage, "success")
|
||||||
|
assert.Equal(t, responseStatus, http.StatusAccepted)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Authenticate Album", func(t *testing.T) {
|
||||||
|
t.Run("Authorized request", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/download/album/1", strings.NewReader("ALBUM DATA"))
|
||||||
|
ctx := auth.AddUserToContext(req.Context(), user)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
success, responseMessage, responseStatus, err := authenticateAlbum(&album, db, req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, success)
|
||||||
|
assert.Equal(t, responseMessage, "success")
|
||||||
|
assert.Equal(t, responseStatus, http.StatusAccepted)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Request without access token", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/download/album/1", strings.NewReader("ALBUM DATA"))
|
||||||
|
|
||||||
|
success, responseMessage, responseStatus, err := authenticateAlbum(&album, db, req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, success)
|
||||||
|
assert.Equal(t, responseMessage, "unauthorized")
|
||||||
|
assert.Equal(t, responseStatus, http.StatusForbidden)
|
||||||
|
})
|
||||||
|
|
||||||
|
expire := time.Now().Add(time.Hour * 24 * 30)
|
||||||
|
tokenPassword := "token-password-123"
|
||||||
|
shareToken, err := actions.AddAlbumShare(db, user, album.ID, &expire, &tokenPassword)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Request with share token", func(t *testing.T) {
|
||||||
|
url := fmt.Sprintf("/download/album/1?token=%s", shareToken.Value)
|
||||||
|
req := httptest.NewRequest("GET", url, strings.NewReader("ALBUM DATA"))
|
||||||
|
|
||||||
|
cookie := http.Cookie{
|
||||||
|
Name: fmt.Sprintf("share-token-pw-%s", shareToken.Value),
|
||||||
|
Value: tokenPassword,
|
||||||
|
}
|
||||||
|
req.AddCookie(&cookie)
|
||||||
|
|
||||||
|
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, success)
|
||||||
|
assert.Equal(t, responseMessage, "success")
|
||||||
|
assert.Equal(t, responseStatus, http.StatusAccepted)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/photoview/photoview/api/graphql/models"
|
"github.com/photoview/photoview/api/graphql/models"
|
||||||
|
@ -17,6 +18,7 @@ func RegisterDownloadRoutes(db *gorm.DB, router *mux.Router) {
|
||||||
router.HandleFunc("/album/{album_id}/{media_purpose}", func(w http.ResponseWriter, r *http.Request) {
|
router.HandleFunc("/album/{album_id}/{media_purpose}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
albumID := mux.Vars(r)["album_id"]
|
albumID := mux.Vars(r)["album_id"]
|
||||||
mediaPurpose := mux.Vars(r)["media_purpose"]
|
mediaPurpose := mux.Vars(r)["media_purpose"]
|
||||||
|
mediaPurposeList := strings.SplitN(mediaPurpose, ",", 10)
|
||||||
|
|
||||||
var album models.Album
|
var album models.Album
|
||||||
if err := db.Find(&album, albumID).Error; err != nil {
|
if err := db.Find(&album, albumID).Error; err != nil {
|
||||||
|
@ -25,13 +27,28 @@ func RegisterDownloadRoutes(db *gorm.DB, router *mux.Router) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if success, response, status, err := authenticateAlbum(&album, db, r); !success {
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("WARN: error authenticating album for download: %v\n", err)
|
||||||
|
}
|
||||||
|
w.WriteHeader(status)
|
||||||
|
w.Write([]byte(response))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var mediaURLs []*models.MediaURL
|
var mediaURLs []*models.MediaURL
|
||||||
if err := db.Joins("Media").Where("media.album_id = ?", album.ID).Where("media_urls.purpose = ?", mediaPurpose).Find(&mediaURLs).Error; err != nil {
|
if err := db.Joins("Media").Where("media.album_id = ?", album.ID).Where("media_urls.purpose IN (?)", mediaPurposeList).Find(&mediaURLs).Error; err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
w.Write([]byte("internal server error"))
|
w.Write([]byte("internal server error"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(mediaURLs) == 0 {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte("no media found"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/zip")
|
w.Header().Set("Content-Type", "application/zip")
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.zip\"", album.Title))
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.zip\"", album.Title))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue