Fix faces not getting scanned
- This fixes #344 - Add integration tests for face recognition - Properly check that the user own the queried album
This commit is contained in:
parent
d03923992c
commit
1029b61a4c
|
@ -36,6 +36,8 @@ models:
|
|||
resolver: true
|
||||
faces:
|
||||
resolver: true
|
||||
mediaType:
|
||||
resolver: true
|
||||
MediaURL:
|
||||
model: github.com/photoview/photoview/api/graphql/models.MediaURL
|
||||
MediaEXIF:
|
||||
|
@ -60,3 +62,5 @@ models:
|
|||
model: github.com/photoview/photoview/api/graphql/models.FaceRectangle
|
||||
SiteInfo:
|
||||
model: github.com/photoview/photoview/api/graphql/models.SiteInfo
|
||||
MediaType:
|
||||
model: github.com/photoview/photoview/api/graphql/models.MediaType
|
||||
|
|
|
@ -11993,13 +11993,19 @@ func (ec *executionContext) marshalNMediaDownload2ᚖgithubᚗcomᚋphotoviewᚋ
|
|||
}
|
||||
|
||||
func (ec *executionContext) unmarshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, v interface{}) (models.MediaType, error) {
|
||||
var res models.MediaType
|
||||
err := res.UnmarshalGQL(v)
|
||||
tmp, err := graphql.UnmarshalString(v)
|
||||
res := models.MediaType(tmp)
|
||||
return res, graphql.ErrorOnPath(ctx, err)
|
||||
}
|
||||
|
||||
func (ec *executionContext) marshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, sel ast.SelectionSet, v models.MediaType) graphql.Marshaler {
|
||||
return v
|
||||
res := graphql.MarshalString(string(v))
|
||||
if res == graphql.Null {
|
||||
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
|
||||
ec.Errorf(ctx, "must not be null")
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (ec *executionContext) marshalNMediaURL2ᚖgithubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaURL(ctx context.Context, sel ast.SelectionSet, v *models.MediaURL) graphql.Marshaler {
|
||||
|
|
|
@ -32,7 +32,10 @@ func (a *Album) BeforeSave(tx *gorm.DB) (err error) {
|
|||
// GetChildren performs a recursive query to get all the children of the album.
|
||||
// An optional filter can be provided that can be used to modify the query on the children.
|
||||
func (a *Album) GetChildren(db *gorm.DB, filter func(*gorm.DB) *gorm.DB) (children []*Album, err error) {
|
||||
// SELECT * FROM sub_albums
|
||||
return GetChildrenFromAlbums(db, filter, []int{a.ID})
|
||||
}
|
||||
|
||||
func GetChildrenFromAlbums(db *gorm.DB, filter func(*gorm.DB) *gorm.DB, albumIDs []int) (children []*Album, err error) {
|
||||
query := db.Model(&Album{}).Table("sub_albums")
|
||||
|
||||
if filter != nil {
|
||||
|
@ -41,13 +44,13 @@ func (a *Album) GetChildren(db *gorm.DB, filter func(*gorm.DB) *gorm.DB) (childr
|
|||
|
||||
err = db.Raw(`
|
||||
WITH recursive sub_albums AS (
|
||||
SELECT * FROM albums AS root WHERE id = ?
|
||||
SELECT * FROM albums AS root WHERE id IN (?)
|
||||
UNION ALL
|
||||
SELECT child.* FROM albums AS child JOIN sub_albums ON child.parent_album_id = sub_albums.id
|
||||
)
|
||||
|
||||
?
|
||||
`, a.ID, query).Find(&children).Error
|
||||
`, albumIDs, query).Find(&children).Error
|
||||
|
||||
return children, err
|
||||
}
|
||||
|
|
|
@ -121,47 +121,6 @@ func (e LanguageTranslation) MarshalGQL(w io.Writer) {
|
|||
fmt.Fprint(w, strconv.Quote(e.String()))
|
||||
}
|
||||
|
||||
type MediaType string
|
||||
|
||||
const (
|
||||
MediaTypePhoto MediaType = "Photo"
|
||||
MediaTypeVideo MediaType = "Video"
|
||||
)
|
||||
|
||||
var AllMediaType = []MediaType{
|
||||
MediaTypePhoto,
|
||||
MediaTypeVideo,
|
||||
}
|
||||
|
||||
func (e MediaType) IsValid() bool {
|
||||
switch e {
|
||||
case MediaTypePhoto, MediaTypeVideo:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e MediaType) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
func (e *MediaType) UnmarshalGQL(v interface{}) error {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("enums must be strings")
|
||||
}
|
||||
|
||||
*e = MediaType(str)
|
||||
if !e.IsValid() {
|
||||
return fmt.Errorf("%s is not a valid MediaType", str)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e MediaType) MarshalGQL(w io.Writer) {
|
||||
fmt.Fprint(w, strconv.Quote(e.String()))
|
||||
}
|
||||
|
||||
type NotificationType string
|
||||
|
||||
const (
|
||||
|
|
|
@ -42,30 +42,19 @@ func (m *Media) BeforeSave(tx *gorm.DB) error {
|
|||
// Update path hash
|
||||
m.PathHash = MD5Hash(m.Path)
|
||||
|
||||
// Save media type as lowercase for better compatibility
|
||||
m.Type = MediaType(strings.ToLower(string(m.Type)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Media) AfterFind(tx *gorm.DB) error {
|
||||
type MediaType string
|
||||
|
||||
// Convert lowercased media type back
|
||||
lowercasedType := strings.ToLower(string(m.Type))
|
||||
foundType := false
|
||||
for _, t := range AllMediaType {
|
||||
if strings.ToLower(string(t)) == lowercasedType {
|
||||
m.Type = t
|
||||
foundType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
const (
|
||||
MediaTypePhoto MediaType = "photo"
|
||||
MediaTypeVideo MediaType = "video"
|
||||
)
|
||||
|
||||
if foundType == false {
|
||||
return errors.New(fmt.Sprintf("Failed to parse media from DB: Invalid media type: %s", m.Type))
|
||||
}
|
||||
|
||||
return nil
|
||||
var AllMediaType = []MediaType{
|
||||
MediaTypePhoto,
|
||||
MediaTypeVideo,
|
||||
}
|
||||
|
||||
type MediaPurpose string
|
||||
|
|
|
@ -173,11 +173,29 @@ func (user *User) FillAlbums(db *gorm.DB) error {
|
|||
}
|
||||
|
||||
func (user *User) OwnsAlbum(db *gorm.DB, album *Album) (bool, error) {
|
||||
// TODO: Implement this
|
||||
panic("not implemented")
|
||||
|
||||
if err := user.FillAlbums(db); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
albumIDs := make([]int, 0)
|
||||
for _, a := range user.Albums {
|
||||
albumIDs = append(albumIDs, a.ID)
|
||||
}
|
||||
|
||||
filter := func(query *gorm.DB) *gorm.DB {
|
||||
return query.Where("id = ?", album.ID)
|
||||
}
|
||||
|
||||
ownedAlbum, err := GetChildrenFromAlbums(db, filter, albumIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return len(ownedAlbum) > 0, nil
|
||||
}
|
||||
|
||||
func (user *User) OwnsMedia(db *gorm.DB, media *Media) (bool, error) {
|
||||
// TODO: implement this
|
||||
panic("not implemented")
|
||||
}
|
||||
// func (user *User) OwnsMedia(db *gorm.DB, media *Media) (bool, error) {
|
||||
// // TODO: implement this
|
||||
// panic("not implemented")
|
||||
// }
|
||||
|
|
|
@ -69,15 +69,6 @@ func (r faceGroupResolver) ImageFaces(ctx context.Context, obj *models.FaceGroup
|
|||
return nil, err
|
||||
}
|
||||
|
||||
r.Database.Transaction(func(tx *gorm.DB) error {
|
||||
for i := range imageFaces {
|
||||
if err := imageFaces[i].Media.AfterFind(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return imageFaces, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolvers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
api "github.com/photoview/photoview/api/graphql"
|
||||
"github.com/photoview/photoview/api/graphql/auth"
|
||||
|
@ -108,6 +109,11 @@ func (r *Resolver) Media() api.MediaResolver {
|
|||
return &mediaResolver{r}
|
||||
}
|
||||
|
||||
func (r *mediaResolver) MediaType(ctx context.Context, media *models.Media) (*models.MediaType, error) {
|
||||
formattedType := models.MediaType(strings.Title(string(media.Type)))
|
||||
return &formattedType, nil
|
||||
}
|
||||
|
||||
func (r *mediaResolver) Shares(ctx context.Context, media *models.Media) ([]*models.ShareToken, error) {
|
||||
var shareTokens []*models.ShareToken
|
||||
if err := r.Database.Where("media_id = ?", media.ID).Find(&shareTokens).Error; err != nil {
|
||||
|
|
|
@ -2,11 +2,11 @@ package face_detection
|
|||
|
||||
import (
|
||||
"log"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/Kagami/go-face"
|
||||
"github.com/photoview/photoview/api/graphql/models"
|
||||
"github.com/photoview/photoview/api/utils"
|
||||
"github.com/pkg/errors"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
@ -25,7 +25,7 @@ func InitializeFaceDetector(db *gorm.DB) error {
|
|||
|
||||
log.Println("Initializing face detector")
|
||||
|
||||
rec, err := face.NewRecognizer(filepath.Join("data", "models"))
|
||||
rec, err := face.NewRecognizer(utils.FaceRecognitionModelsPath())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "initialize facedetect recognizer")
|
||||
}
|
||||
|
@ -107,6 +107,8 @@ func (fd *FaceDetector) DetectFaces(db *gorm.DB, media *models.Media) error {
|
|||
return err
|
||||
}
|
||||
|
||||
log.Printf("Face thumb path: %v %v\n", thumbnailPath, fd)
|
||||
|
||||
fd.mutex.Lock()
|
||||
faces, err := fd.rec.RecognizeFile(thumbnailPath)
|
||||
fd.mutex.Unlock()
|
||||
|
|
|
@ -3,9 +3,11 @@ package scanner_test
|
|||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/photoview/photoview/api/graphql/models"
|
||||
"github.com/photoview/photoview/api/scanner"
|
||||
"github.com/photoview/photoview/api/scanner/face_detection"
|
||||
"github.com/photoview/photoview/api/test_utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -42,6 +44,10 @@ func TestFullScan(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
if !assert.NoError(t, face_detection.InitializeFaceDetector(db)) {
|
||||
return
|
||||
}
|
||||
|
||||
if !assert.NoError(t, scanner.AddUserToQueue(user)) {
|
||||
return
|
||||
}
|
||||
|
@ -54,6 +60,32 @@ func TestFullScan(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, len(all_media))
|
||||
assert.Equal(t, 9, len(all_media))
|
||||
|
||||
var all_media_url []*models.MediaURL
|
||||
if !assert.NoError(t, db.Find(&all_media_url).Error) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, 18, len(all_media_url))
|
||||
|
||||
// Verify that faces was recognized
|
||||
assert.Eventually(t, func() bool {
|
||||
var all_face_groups []*models.FaceGroup
|
||||
if !assert.NoError(t, db.Find(&all_face_groups).Error) {
|
||||
return false
|
||||
}
|
||||
|
||||
return len(all_face_groups) == 3
|
||||
}, time.Second*5, time.Millisecond*500)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
var all_image_faces []*models.ImageFace
|
||||
if !assert.NoError(t, db.Find(&all_image_faces).Error) {
|
||||
return false
|
||||
}
|
||||
|
||||
return len(all_image_faces) == 6
|
||||
}, time.Second*5, time.Millisecond*500)
|
||||
|
||||
}
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 292 KiB |
|
@ -32,11 +32,12 @@ func UnitTestRun(m *testing.M) int {
|
|||
func IntegrationTestRun(m *testing.M) int {
|
||||
flag.Parse()
|
||||
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
log.Fatal("could not get runtime file path")
|
||||
}
|
||||
|
||||
if *integration_flags.Database {
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
log.Fatal("could not get runtime file path")
|
||||
}
|
||||
|
||||
envPath := path.Join(path.Dir(file), "..", "testing.env")
|
||||
|
||||
|
@ -45,6 +46,9 @@ func IntegrationTestRun(m *testing.M) int {
|
|||
}
|
||||
}
|
||||
|
||||
faceModelsPath := path.Join(path.Dir(file), "..", "data", "models")
|
||||
utils.ConfigureTestFaceRecognitionModelsPath(faceModelsPath)
|
||||
|
||||
result := m.Run()
|
||||
|
||||
test_dbm.Close()
|
||||
|
|
|
@ -7,10 +7,11 @@ type EnvironmentVariable string
|
|||
|
||||
// General options
|
||||
const (
|
||||
EnvDevelopmentMode EnvironmentVariable = "PHOTOVIEW_DEVELOPMENT_MODE"
|
||||
EnvServeUI EnvironmentVariable = "PHOTOVIEW_SERVE_UI"
|
||||
EnvUIPath EnvironmentVariable = "PHOTOVIEW_UI_PATH"
|
||||
EnvMediaCachePath EnvironmentVariable = "PHOTOVIEW_MEDIA_CACHE"
|
||||
EnvDevelopmentMode EnvironmentVariable = "PHOTOVIEW_DEVELOPMENT_MODE"
|
||||
EnvServeUI EnvironmentVariable = "PHOTOVIEW_SERVE_UI"
|
||||
EnvUIPath EnvironmentVariable = "PHOTOVIEW_UI_PATH"
|
||||
EnvMediaCachePath EnvironmentVariable = "PHOTOVIEW_MEDIA_CACHE"
|
||||
EnvFaceRecognitionModelsPath EnvironmentVariable = "PHOTOVIEW_FACE_RECOGNITION_MODELS_PATH"
|
||||
)
|
||||
|
||||
// Network related
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"path"
|
||||
)
|
||||
|
||||
func GenerateToken() string {
|
||||
|
@ -61,3 +62,21 @@ func MediaCachePath() string {
|
|||
|
||||
return photoCache
|
||||
}
|
||||
|
||||
var test_face_recognition_models_path string = ""
|
||||
|
||||
func ConfigureTestFaceRecognitionModelsPath(path string) {
|
||||
test_face_recognition_models_path = path
|
||||
}
|
||||
|
||||
func FaceRecognitionModelsPath() string {
|
||||
if test_face_recognition_models_path != "" {
|
||||
return test_face_recognition_models_path
|
||||
}
|
||||
|
||||
if EnvFaceRecognitionModelsPath.GetValue() == "" {
|
||||
return path.Join("data", "models")
|
||||
}
|
||||
|
||||
return EnvFaceRecognitionModelsPath.GetValue()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue