2024-07-01 02:41:41 +01:00
package user
import (
"context"
"fmt"
"time"
"donetick.com/core/config"
uModel "donetick.com/core/internal/user/model"
"donetick.com/core/logging"
"gorm.io/gorm"
)
type IUserRepository interface {
GetUserByUsername ( username string ) ( * uModel . User , error )
GetUser ( id int ) ( * uModel . User , error )
GetAllUsers ( ) ( [ ] * uModel . User , error )
CreateUser ( user * uModel . User ) error
UpdateUser ( user * uModel . User ) error
UpdateUserCircle ( userID , circleID int ) error
FindByEmail ( email string ) ( * uModel . User , error )
}
type UserRepository struct {
db * gorm . DB
isDonetickDotCom bool
}
func NewUserRepository ( db * gorm . DB , cfg * config . Config ) * UserRepository {
return & UserRepository { db , cfg . IsDoneTickDotCom }
}
func ( r * UserRepository ) GetAllUsers ( c context . Context , circleID int ) ( [ ] * uModel . User , error ) {
var users [ ] * uModel . User
if err := r . db . WithContext ( c ) . Where ( "circle_id = ?" , circleID ) . Find ( & users ) . Error ; err != nil {
return nil , err
}
return users , nil
}
func ( r * UserRepository ) GetAllUsersForSystemOnly ( c context . Context ) ( [ ] * uModel . User , error ) {
var users [ ] * uModel . User
if err := r . db . WithContext ( c ) . Find ( & users ) . Error ; err != nil {
return nil , err
}
return users , nil
}
func ( r * UserRepository ) CreateUser ( c context . Context , user * uModel . User ) ( * uModel . User , error ) {
if err := r . db . WithContext ( c ) . Save ( user ) . Error ; err != nil {
return nil , err
}
return user , nil
}
func ( r * UserRepository ) GetUserByUsername ( c context . Context , username string ) ( * uModel . User , error ) {
var user * uModel . User
if r . isDonetickDotCom {
if err := r . db . WithContext ( c ) . Table ( "users u" ) . Select ( "u.*, ss.status as subscription, ss.expired_at as expiration" ) . Joins ( "left join stripe_customers sc on sc.user_id = u.id " ) . Joins ( "left join stripe_subscriptions ss on sc.customer_id = ss.customer_id" ) . Where ( "username = ?" , username ) . First ( & user ) . Error ; err != nil {
return nil , err
}
} else {
2024-07-01 23:44:41 +01:00
if err := r . db . WithContext ( c ) . Table ( "users u" ) . Select ( "u.*, 'active' as subscription, '2999-12-31' as expiration" ) . Where ( "username = ?" , username ) . First ( & user ) . Error ; err != nil {
2024-07-01 02:41:41 +01:00
return nil , err
}
}
return user , nil
}
func ( r * UserRepository ) UpdateUser ( c context . Context , user * uModel . User ) error {
return r . db . WithContext ( c ) . Save ( user ) . Error
}
func ( r * UserRepository ) UpdateUserCircle ( c context . Context , userID , circleID int ) error {
return r . db . WithContext ( c ) . Model ( & uModel . User { } ) . Where ( "id = ?" , userID ) . Update ( "circle_id" , circleID ) . Error
}
func ( r * UserRepository ) FindByEmail ( c context . Context , email string ) ( * uModel . User , error ) {
var user * uModel . User
if err := r . db . WithContext ( c ) . Where ( "email = ?" , email ) . First ( & user ) . Error ; err != nil {
return nil , err
}
return user , nil
}
func ( r * UserRepository ) SetPasswordResetToken ( c context . Context , email , token string ) error {
// confirm user exists with email:
user , err := r . FindByEmail ( c , email )
if err != nil {
return err
}
// save new token:
if err := r . db . WithContext ( c ) . Model ( & uModel . UserPasswordReset { } ) . Save ( & uModel . UserPasswordReset {
UserID : user . ID ,
Token : token ,
Email : email ,
ExpirationDate : time . Now ( ) . UTC ( ) . Add ( time . Hour * 24 ) ,
} ) . Error ; err != nil {
return err
}
return nil
}
func ( r * UserRepository ) UpdatePasswordByToken ( ctx context . Context , email string , token string , password string ) error {
logger := logging . FromContext ( ctx )
logger . Debugw ( "account.db.UpdatePasswordByToken" , "email" , email )
upr := & uModel . UserPasswordReset {
Email : email ,
Token : token ,
}
result := r . db . WithContext ( ctx ) . Where ( "email = ?" , email ) . Where ( "token = ?" , token ) . Delete ( upr )
if result . RowsAffected <= 0 {
return fmt . Errorf ( "invalid token" )
}
// find account by email and update password:
chain := r . db . WithContext ( ctx ) . Model ( & uModel . User { } ) . Where ( "email = ?" , email ) . UpdateColumns ( map [ string ] interface { } { "password" : password } )
if chain . Error != nil {
return chain . Error
}
if chain . RowsAffected == 0 {
return fmt . Errorf ( "account not found" )
}
return nil
}
func ( r * UserRepository ) StoreAPIToken ( c context . Context , userID int , name string , tokenCode string ) ( * uModel . APIToken , error ) {
token := & uModel . APIToken {
UserID : userID ,
Name : name ,
Token : tokenCode ,
CreatedAt : time . Now ( ) . UTC ( ) ,
}
if err := r . db . WithContext ( c ) . Model ( & uModel . APIToken { } ) . Save (
token ) . Error ; err != nil {
return nil , err
}
return token , nil
}
func ( r * UserRepository ) GetUserByToken ( c context . Context , token string ) ( * uModel . User , error ) {
var user * uModel . User
if err := r . db . WithContext ( c ) . Table ( "users u" ) . Select ( "u.*" ) . Joins ( "left join api_tokens at on at.user_id = u.id" ) . Where ( "at.token = ?" , token ) . First ( & user ) . Error ; err != nil {
return nil , err
}
return user , nil
}
func ( r * UserRepository ) GetAllUserTokens ( c context . Context , userID int ) ( [ ] * uModel . APIToken , error ) {
var tokens [ ] * uModel . APIToken
if err := r . db . WithContext ( c ) . Where ( "user_id = ?" , userID ) . Find ( & tokens ) . Error ; err != nil {
return nil , err
}
return tokens , nil
}
func ( r * UserRepository ) DeleteAPIToken ( c context . Context , userID int , tokenID string ) error {
return r . db . WithContext ( c ) . Where ( "id = ? AND user_id = ?" , tokenID , userID ) . Delete ( & uModel . APIToken { } ) . Error
}