package repository

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"git.jamestombleson.com/jtom38/newsbot-api/internal/entity"
	"github.com/huandu/go-sqlbuilder"
	"golang.org/x/crypto/bcrypt"
)

const (
	// usersTableName is the table every query in this file targets.
	usersTableName string = "users"

	// ErrUserNotFound is the message of the error GetByName returns when no
	// row matches the requested name. It is a string constant, not an error
	// value, so callers have to compare it against err.Error().
	ErrUserNotFound string = "requested user was not found"
)

// Users describes the persistence operations available for user records.
type Users interface {
	GetByName(ctx context.Context, name string) (entity.UserEntity, error)
	Create(ctx context.Context, name, password, sessionToken, scope string) (int64, error)
	Update(ctx context.Context, id int, entity entity.UserEntity) error
	UpdatePassword(ctx context.Context, name, password string) error
	CheckUserHash(ctx context.Context, name, password string) error
	UpdateScopes(ctx context.Context, name, scope string) error
	UpdateSessionToken(ctx context.Context, name, sessionToken string) (int64, error)
}

// Compile-time check that userRepository satisfies Users.
var _ Users = userRepository{}

// NewUserRepository returns a userRepository bound to the given connection.
func NewUserRepository(conn *sql.DB) userRepository {
	return userRepository{
		connection: conn,
	}
}

// userRepository implements Users on top of a *sql.DB.
type userRepository struct {
	connection *sql.DB
}

// GetByName returns the first user whose Name column equals name.
//
// When no row matches, the returned error carries the ErrUserNotFound
// message. Query, scan and row-iteration errors are returned to the caller.
func (ur userRepository) GetByName(ctx context.Context, name string) (entity.UserEntity, error) {
	builder := sqlbuilder.NewSelectBuilder()
	builder.Select("*").From(usersTableName).Where(
		builder.E("Name", name),
	)
	query, args := builder.Build()

	rows, err := ur.connection.QueryContext(ctx, query, args...)
	if err != nil {
		return entity.UserEntity{}, err
	}
	// Release the underlying connection even when we return before the
	// result set has been fully consumed.
	defer rows.Close()

	users, err := ur.processRows(rows)
	if err != nil {
		return entity.UserEntity{}, err
	}
	if len(users) == 0 {
		return entity.UserEntity{}, errors.New(ErrUserNotFound)
	}

	return users[0], nil
}

// Create inserts a new user. The password is hashed with bcrypt at
// bcrypt.DefaultCost and only the hash is stored. CreatedAt and UpdatedAt are
// both set to the current time and DeletedAt is written as the zero time.
//
// It returns 1 when the insert succeeds and 0 together with the error
// otherwise.
func (ur userRepository) Create(ctx context.Context, name, password, sessionToken, scope string) (int64, error) {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return 0, err
	}

	now := time.Now()
	builder := sqlbuilder.NewInsertBuilder()
	builder.InsertInto(usersTableName)
	builder.Cols("Name", "Hash", "UpdatedAt", "CreatedAt", "DeletedAt", "Scopes", "SessionToken")
	builder.Values(name, string(hash), now, now, time.Time{}, scope, sessionToken)
	query, args := builder.Build()

	if _, err := ur.connection.ExecContext(ctx, query, args...); err != nil {
		return 0, err
	}

	return 1, nil
}

// Update is not implemented and always returns an error.
func (ur userRepository) Update(ctx context.Context, id int, entity entity.UserEntity) error {
	return errors.New("not implemented")
}

// UpdatePassword replaces the stored hash of the named user with a bcrypt
// hash of password.
//
// It returns the GetByName error when the user cannot be loaded, and any
// hashing or database error otherwise.
func (ur userRepository) UpdatePassword(ctx context.Context, name, password string) error {
	if _, err := ur.GetByName(ctx, name); err != nil {
		return err
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("hashing password: %w", err)
	}

	builder := sqlbuilder.NewUpdateBuilder()
	builder.Update(usersTableName)
	builder.Set(
		builder.Assign("Hash", string(hash)),
	)
	builder.Where(
		builder.Equal("Name", name),
	)
	query, args := builder.Build()

	_, err = ur.connection.ExecContext(ctx, query, args...)
	return err
}

// UpdateSessionToken stores sessionToken on the named user and returns the
// number of rows the update affected.
//
// The user is looked up first, so a missing user yields the GetByName error
// rather than a zero row count.
func (ur userRepository) UpdateSessionToken(ctx context.Context, name, sessionToken string) (int64, error) {
	if _, err := ur.GetByName(ctx, name); err != nil {
		return 0, err
	}

	builder := sqlbuilder.NewUpdateBuilder()
	builder.Update(usersTableName)
	builder.Set(
		builder.Assign("SessionToken", sessionToken),
	)
	builder.Where(
		builder.Equal("Name", name),
	)
	query, args := builder.Build()

	result, err := ur.connection.ExecContext(ctx, query, args...)
	if err != nil {
		return 0, err
	}

	return result.RowsAffected()
}

// CheckUserHash compares password against the bcrypt hash stored for name.
//
// If the hash matches what we have in the database, an error will not be
// returned. If the user does not exist or the hash does not match, an error
// will be returned.
func (ur userRepository) CheckUserHash(ctx context.Context, name, password string) error {
	record, err := ur.GetByName(ctx, name)
	if err != nil {
		return err
	}

	return bcrypt.CompareHashAndPassword([]byte(record.Hash), []byte(password))
}

// UpdateScopes overwrites the Scopes column of the named user. It does not
// check that the user exists; only an error from executing the statement is
// returned.
func (ur userRepository) UpdateScopes(ctx context.Context, name, scope string) error {
	builder := sqlbuilder.NewUpdateBuilder()
	builder.Update(usersTableName)
	builder.Set(
		builder.Assign("Scopes", scope),
	)
	builder.Where(
		builder.Equal("Name", name),
	)
	query, args := builder.Build()

	_, err := ur.connection.ExecContext(ctx, query, args...)
	return err
}

// processRows converts every row in rows into an entity.UserEntity.
//
// The scan order below has to line up with the column order of the result
// set, which GetByName obtains via SELECT *. The first scan or iteration
// error aborts the conversion and is returned. Closing rows is left to the
// caller.
func (ur userRepository) processRows(rows *sql.Rows) ([]entity.UserEntity, error) {
	var items []entity.UserEntity

	for rows.Next() {
		var (
			id           int64
			username     string
			hash         string
			createdAt    time.Time
			updatedAt    time.Time
			deletedAt    sql.NullTime
			scopes       string
			sessionToken string
		)

		err := rows.Scan(&id, &createdAt, &updatedAt, &deletedAt, &username, &hash, &scopes, &sessionToken)
		if err != nil {
			return nil, fmt.Errorf("scanning user row: %w", err)
		}

		item := entity.UserEntity{
			ID:           id,
			UpdatedAt:    updatedAt,
			Username:     username,
			Hash:         hash,
			Scopes:       scopes,
			CreatedAt:    createdAt,
			SessionToken: sessionToken,
		}
		// DeletedAt is nullable; leave the entity's zero value when NULL.
		if deletedAt.Valid {
			item.DeletedAt = deletedAt.Time
		}

		items = append(items, item)
	}

	// rows.Next reports false both at the end of the data and on failure;
	// rows.Err tells the two cases apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating user rows: %w", err)
	}

	return items, nil
}