1
Fork 0

Fix faces not getting scanned

- This fixes #344
- Add integration tests for face recognition
- Properly check that the user own the queried album
This commit is contained in:
viktorstrate 2021-04-26 12:21:15 +02:00
parent d03923992c
commit 1029b61a4c
No known key found for this signature in database
GPG Key ID: 3F855605109C1E8A
14 changed files with 126 additions and 92 deletions

View File

@ -36,6 +36,8 @@ models:
resolver: true
faces:
resolver: true
mediaType:
resolver: true
MediaURL:
model: github.com/photoview/photoview/api/graphql/models.MediaURL
MediaEXIF:
@ -60,3 +62,5 @@ models:
model: github.com/photoview/photoview/api/graphql/models.FaceRectangle
SiteInfo:
model: github.com/photoview/photoview/api/graphql/models.SiteInfo
MediaType:
model: github.com/photoview/photoview/api/graphql/models.MediaType

View File

@ -11993,13 +11993,19 @@ func (ec *executionContext) marshalNMediaDownload2ᚖgithubᚗcomᚋphotoviewᚋ
}
func (ec *executionContext) unmarshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, v interface{}) (models.MediaType, error) {
var res models.MediaType
err := res.UnmarshalGQL(v)
tmp, err := graphql.UnmarshalString(v)
res := models.MediaType(tmp)
return res, graphql.ErrorOnPath(ctx, err)
}
func (ec *executionContext) marshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, sel ast.SelectionSet, v models.MediaType) graphql.Marshaler {
return v
res := graphql.MarshalString(string(v))
if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")
}
}
return res
}
func (ec *executionContext) marshalNMediaURL2ᚖgithubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaURL(ctx context.Context, sel ast.SelectionSet, v *models.MediaURL) graphql.Marshaler {

View File

@ -32,7 +32,10 @@ func (a *Album) BeforeSave(tx *gorm.DB) (err error) {
// GetChildren performs a recursive query to get all the children of the album.
// An optional filter can be provided that can be used to modify the query on the children.
func (a *Album) GetChildren(db *gorm.DB, filter func(*gorm.DB) *gorm.DB) (children []*Album, err error) {
// SELECT * FROM sub_albums
return GetChildrenFromAlbums(db, filter, []int{a.ID})
}
func GetChildrenFromAlbums(db *gorm.DB, filter func(*gorm.DB) *gorm.DB, albumIDs []int) (children []*Album, err error) {
query := db.Model(&Album{}).Table("sub_albums")
if filter != nil {
@ -41,13 +44,13 @@ func (a *Album) GetChildren(db *gorm.DB, filter func(*gorm.DB) *gorm.DB) (childr
err = db.Raw(`
WITH recursive sub_albums AS (
SELECT * FROM albums AS root WHERE id = ?
SELECT * FROM albums AS root WHERE id IN (?)
UNION ALL
SELECT child.* FROM albums AS child JOIN sub_albums ON child.parent_album_id = sub_albums.id
)
?
`, a.ID, query).Find(&children).Error
`, albumIDs, query).Find(&children).Error
return children, err
}

View File

@ -121,47 +121,6 @@ func (e LanguageTranslation) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type MediaType string
const (
MediaTypePhoto MediaType = "Photo"
MediaTypeVideo MediaType = "Video"
)
var AllMediaType = []MediaType{
MediaTypePhoto,
MediaTypeVideo,
}
func (e MediaType) IsValid() bool {
switch e {
case MediaTypePhoto, MediaTypeVideo:
return true
}
return false
}
func (e MediaType) String() string {
return string(e)
}
func (e *MediaType) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = MediaType(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid MediaType", str)
}
return nil
}
func (e MediaType) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type NotificationType string
const (

View File

@ -42,30 +42,19 @@ func (m *Media) BeforeSave(tx *gorm.DB) error {
// Update path hash
m.PathHash = MD5Hash(m.Path)
// Save media type as lowercase for better compatibility
m.Type = MediaType(strings.ToLower(string(m.Type)))
return nil
}
func (m *Media) AfterFind(tx *gorm.DB) error {
type MediaType string
// Convert lowercased media type back
lowercasedType := strings.ToLower(string(m.Type))
foundType := false
for _, t := range AllMediaType {
if strings.ToLower(string(t)) == lowercasedType {
m.Type = t
foundType = true
break
}
}
const (
MediaTypePhoto MediaType = "photo"
MediaTypeVideo MediaType = "video"
)
if foundType == false {
return errors.New(fmt.Sprintf("Failed to parse media from DB: Invalid media type: %s", m.Type))
}
return nil
var AllMediaType = []MediaType{
MediaTypePhoto,
MediaTypeVideo,
}
type MediaPurpose string

View File

@ -173,11 +173,29 @@ func (user *User) FillAlbums(db *gorm.DB) error {
}
func (user *User) OwnsAlbum(db *gorm.DB, album *Album) (bool, error) {
// TODO: Implement this
panic("not implemented")
if err := user.FillAlbums(db); err != nil {
return false, err
}
albumIDs := make([]int, 0)
for _, a := range user.Albums {
albumIDs = append(albumIDs, a.ID)
}
filter := func(query *gorm.DB) *gorm.DB {
return query.Where("id = ?", album.ID)
}
ownedAlbum, err := GetChildrenFromAlbums(db, filter, albumIDs)
if err != nil {
return false, err
}
return len(ownedAlbum) > 0, nil
}
func (user *User) OwnsMedia(db *gorm.DB, media *Media) (bool, error) {
// TODO: implement this
panic("not implemented")
}
// func (user *User) OwnsMedia(db *gorm.DB, media *Media) (bool, error) {
// // TODO: implement this
// panic("not implemented")
// }

View File

@ -69,15 +69,6 @@ func (r faceGroupResolver) ImageFaces(ctx context.Context, obj *models.FaceGroup
return nil, err
}
r.Database.Transaction(func(tx *gorm.DB) error {
for i := range imageFaces {
if err := imageFaces[i].Media.AfterFind(tx); err != nil {
return err
}
}
return nil
})
return imageFaces, nil
}

View File

@ -2,6 +2,7 @@ package resolvers
import (
"context"
"strings"
api "github.com/photoview/photoview/api/graphql"
"github.com/photoview/photoview/api/graphql/auth"
@ -108,6 +109,11 @@ func (r *Resolver) Media() api.MediaResolver {
return &mediaResolver{r}
}
func (r *mediaResolver) MediaType(ctx context.Context, media *models.Media) (*models.MediaType, error) {
formattedType := models.MediaType(strings.Title(string(media.Type)))
return &formattedType, nil
}
func (r *mediaResolver) Shares(ctx context.Context, media *models.Media) ([]*models.ShareToken, error) {
var shareTokens []*models.ShareToken
if err := r.Database.Where("media_id = ?", media.ID).Find(&shareTokens).Error; err != nil {

View File

@ -2,11 +2,11 @@ package face_detection
import (
"log"
"path/filepath"
"sync"
"github.com/Kagami/go-face"
"github.com/photoview/photoview/api/graphql/models"
"github.com/photoview/photoview/api/utils"
"github.com/pkg/errors"
"gorm.io/gorm"
)
@ -25,7 +25,7 @@ func InitializeFaceDetector(db *gorm.DB) error {
log.Println("Initializing face detector")
rec, err := face.NewRecognizer(filepath.Join("data", "models"))
rec, err := face.NewRecognizer(utils.FaceRecognitionModelsPath())
if err != nil {
return errors.Wrap(err, "initialize facedetect recognizer")
}
@ -107,6 +107,8 @@ func (fd *FaceDetector) DetectFaces(db *gorm.DB, media *models.Media) error {
return err
}
log.Printf("Face thumb path: %v %v\n", thumbnailPath, fd)
fd.mutex.Lock()
faces, err := fd.rec.RecognizeFile(thumbnailPath)
fd.mutex.Unlock()

View File

@ -3,9 +3,11 @@ package scanner_test
import (
"os"
"testing"
"time"
"github.com/photoview/photoview/api/graphql/models"
"github.com/photoview/photoview/api/scanner"
"github.com/photoview/photoview/api/scanner/face_detection"
"github.com/photoview/photoview/api/test_utils"
"github.com/stretchr/testify/assert"
)
@ -42,6 +44,10 @@ func TestFullScan(t *testing.T) {
return
}
if !assert.NoError(t, face_detection.InitializeFaceDetector(db)) {
return
}
if !assert.NoError(t, scanner.AddUserToQueue(user)) {
return
}
@ -54,6 +60,32 @@ func TestFullScan(t *testing.T) {
return
}
assert.Equal(t, 10, len(all_media))
assert.Equal(t, 9, len(all_media))
var all_media_url []*models.MediaURL
if !assert.NoError(t, db.Find(&all_media_url).Error) {
return
}
assert.Equal(t, 18, len(all_media_url))
// Verify that faces was recognized
assert.Eventually(t, func() bool {
var all_face_groups []*models.FaceGroup
if !assert.NoError(t, db.Find(&all_face_groups).Error) {
return false
}
return len(all_face_groups) == 3
}, time.Second*5, time.Millisecond*500)
assert.Eventually(t, func() bool {
var all_image_faces []*models.ImageFace
if !assert.NoError(t, db.Find(&all_image_faces).Error) {
return false
}
return len(all_image_faces) == 6
}, time.Second*5, time.Millisecond*500)
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 292 KiB

View File

@ -32,12 +32,13 @@ func UnitTestRun(m *testing.M) int {
func IntegrationTestRun(m *testing.M) int {
flag.Parse()
if *integration_flags.Database {
_, file, _, ok := runtime.Caller(0)
if !ok {
log.Fatal("could not get runtime file path")
}
if *integration_flags.Database {
envPath := path.Join(path.Dir(file), "..", "testing.env")
if err := godotenv.Load(envPath); err != nil {
@ -45,6 +46,9 @@ func IntegrationTestRun(m *testing.M) int {
}
}
faceModelsPath := path.Join(path.Dir(file), "..", "data", "models")
utils.ConfigureTestFaceRecognitionModelsPath(faceModelsPath)
result := m.Run()
test_dbm.Close()

View File

@ -11,6 +11,7 @@ const (
EnvServeUI EnvironmentVariable = "PHOTOVIEW_SERVE_UI"
EnvUIPath EnvironmentVariable = "PHOTOVIEW_UI_PATH"
EnvMediaCachePath EnvironmentVariable = "PHOTOVIEW_MEDIA_CACHE"
EnvFaceRecognitionModelsPath EnvironmentVariable = "PHOTOVIEW_FACE_RECOGNITION_MODELS_PATH"
)
// Network related

View File

@ -5,6 +5,7 @@ import (
"fmt"
"log"
"math/big"
"path"
)
func GenerateToken() string {
@ -61,3 +62,21 @@ func MediaCachePath() string {
return photoCache
}
var test_face_recognition_models_path string = ""
func ConfigureTestFaceRecognitionModelsPath(path string) {
test_face_recognition_models_path = path
}
func FaceRecognitionModelsPath() string {
if test_face_recognition_models_path != "" {
return test_face_recognition_models_path
}
if EnvFaceRecognitionModelsPath.GetValue() == "" {
return path.Join("data", "models")
}
return EnvFaceRecognitionModelsPath.GetValue()
}