newsbot-api/internal/repository/users.go

193 lines
4.7 KiB
Go

package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"git.jamestombleson.com/jtom38/newsbot-api/internal/entity"
"github.com/huandu/go-sqlbuilder"
"golang.org/x/crypto/bcrypt"
)
// usersTableName is the name of the table holding user rows.
const usersTableName string = "users"

// ErrUserNotFound is the message of the error GetByName builds (via
// errors.New) when no user row matches the requested name. It is a string
// constant, not an error value.
const ErrUserNotFound string = "requested user was not found"
// Users lists the persistence operations available for user accounts.
type Users interface {
	// GetByName fetches the user stored under name.
	GetByName(ctx context.Context, name string) (entity.UserEntity, error)
	// Create stores a new user with the given name, password, session token
	// and scope.
	Create(ctx context.Context, name, password, sessionToken, scope string) (int64, error)
	// Update changes the user identified by id to match user.
	Update(ctx context.Context, id int, user entity.UserEntity) error
	// UpdatePassword sets a new password for the named user.
	UpdatePassword(ctx context.Context, name, password string) error
	// CheckUserHash returns nil when password matches the stored credentials
	// for name, and a non-nil error otherwise.
	CheckUserHash(ctx context.Context, name, password string) error
	// UpdateScopes replaces the scope string stored for the named user.
	UpdateScopes(ctx context.Context, name, scope string) error
	// UpdateSessionToken stores a new session token for the named user.
	UpdateSessionToken(ctx context.Context, name, sessionToken string) (int64, error)
}
// userRepository is the *sql.DB-backed type whose methods run the user
// queries; every method goes through the connection field.
type userRepository struct {
	connection *sql.DB
}

// NewUserRepository returns a userRepository that issues all of its queries
// through conn. The connection is stored as-is; no validation is performed.
func NewUserRepository(conn *sql.DB) userRepository {
	var repo userRepository
	repo.connection = conn
	return repo
}
// GetByName looks up a user by the Name column.
//
// If no row matches, the returned error's message is ErrUserNotFound. Errors
// from running the query or from iterating the result set are returned
// unchanged. If several rows share the same name, the first row read is
// returned.
func (ur userRepository) GetByName(ctx context.Context, name string) (entity.UserEntity, error) {
	builder := sqlbuilder.NewSelectBuilder()
	builder.Select("*").From("users").Where(
		builder.E("Name", name),
	)
	query, args := builder.Build()

	rows, err := ur.connection.QueryContext(ctx, query, args...)
	if err != nil {
		return entity.UserEntity{}, err
	}
	// Close is idempotent; deferring it guarantees the connection is
	// released regardless of how row processing ends.
	defer rows.Close()

	data := ur.processRows(rows)
	// rows.Next returning false means either "no more rows" or an iteration
	// failure (e.g. a cancelled ctx). Only rows.Err distinguishes the two, so
	// check it before treating an empty result as "not found".
	if err := rows.Err(); err != nil {
		return entity.UserEntity{}, err
	}
	if len(data) == 0 {
		return entity.UserEntity{}, errors.New(ErrUserNotFound)
	}
	return data[0], nil
}
// Create inserts a new user row.
//
// The plain-text password is hashed with bcrypt at bcrypt.DefaultCost, and
// only the hash is stored. CreatedAt and UpdatedAt are both set to the current
// time, and DeletedAt is written as the zero time.Time. The returned count is
// the number of rows the driver reports as affected (normally 1).
func (ur userRepository) Create(ctx context.Context, name, password, sessionToken, scope string) (int64, error) {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return 0, err
	}

	now := time.Now()
	queryBuilder := sqlbuilder.NewInsertBuilder()
	queryBuilder.InsertInto("users")
	queryBuilder.Cols("Name", "Hash", "UpdatedAt", "CreatedAt", "DeletedAt", "Scopes", "SessionToken")
	queryBuilder.Values(name, string(hash), now, now, time.Time{}, scope, sessionToken)
	query, args := queryBuilder.Build()

	result, err := ur.connection.ExecContext(ctx, query, args...)
	if err != nil {
		return 0, err
	}
	// Report what the database actually did rather than a hard-coded 1,
	// the same way UpdateSessionToken reports its result.
	return result.RowsAffected()
}
// Update is a placeholder for persisting changes to the user identified by id.
// It is not implemented: every call touches no data and returns an error whose
// message is "not implemented".
func (ur userRepository) Update(ctx context.Context, id int, entity entity.UserEntity) error {
	notImplemented := errors.New("not implemented")
	return notImplemented
}
// UpdatePassword replaces the stored password hash for the named user.
//
// The user must already exist. Otherwise the error from GetByName (whose
// message is ErrUserNotFound when no row matches) is returned. The new
// password is hashed with bcrypt at bcrypt.DefaultCost, the same cost Create
// uses, and UpdatedAt is set to the current time. Errors from hashing or from
// executing the UPDATE are returned.
func (ur userRepository) UpdatePassword(ctx context.Context, name, password string) error {
	// Do not report success for a user that does not exist.
	if _, err := ur.GetByName(ctx, name); err != nil {
		return err
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}

	queryBuilder := sqlbuilder.NewUpdateBuilder()
	queryBuilder.Update(usersTableName)
	queryBuilder.Set(
		queryBuilder.Assign("Hash", string(hash)),
		queryBuilder.Assign("UpdatedAt", time.Now()),
	)
	queryBuilder.Where(
		queryBuilder.Equal("Name", name),
	)
	query, args := queryBuilder.Build()

	_, err = ur.connection.ExecContext(ctx, query, args...)
	return err
}
// UpdateSessionToken stores sessionToken on the record of the named user.
//
// The user must already exist. Otherwise the error from GetByName is
// returned. On success it returns the number of rows the driver reports as
// affected by the UPDATE.
func (ur userRepository) UpdateSessionToken(ctx context.Context, name, sessionToken string) (int64, error) {
	if _, err := ur.GetByName(ctx, name); err != nil {
		return 0, err
	}

	q := sqlbuilder.NewUpdateBuilder()
	q.Update(usersTableName)
	// Assign is UpdateBuilder's helper for SET expressions (also used by
	// UpdateScopes). Equal is the WHERE-condition helper.
	q.Set(
		q.Assign("SessionToken", sessionToken),
	)
	q.Where(
		q.Equal("Name", name),
	)
	query, args := q.Build()

	result, err := ur.connection.ExecContext(ctx, query, args...)
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// CheckUserHash verifies password against the bcrypt hash stored for name.
//
// It returns nil only when the user exists and the password matches.
// Otherwise it returns, unchanged, either the lookup error from GetByName
// (e.g. the ErrUserNotFound message) or the error from
// bcrypt.CompareHashAndPassword.
func (ur userRepository) CheckUserHash(ctx context.Context, name, password string) error {
	user, lookupErr := ur.GetByName(ctx, name)
	if lookupErr != nil {
		return lookupErr
	}
	return bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password))
}
// UpdateScopes writes scope into the Scopes column of the row(s) whose Name
// equals name.
//
// It does not look the user up first, so an unknown name simply updates no
// rows. Only an error from executing the statement is returned.
func (ur userRepository) UpdateScopes(ctx context.Context, name, scope string) error {
	ub := sqlbuilder.NewUpdateBuilder()
	ub.Update("users")
	ub.Set(ub.Assign("Scopes", scope))
	ub.Where(ub.Equal("Name", name))

	sqlText, sqlArgs := ub.Build()
	if _, execErr := ur.connection.ExecContext(ctx, sqlText, sqlArgs...); execErr != nil {
		return execErr
	}
	return nil
}
// processRows scans every remaining row of rows into a UserEntity.
//
// Scan is positional. Each row is read in the order id, createdAt, updatedAt,
// deletedAt, username, hash, scopes, sessionToken.
// NOTE(review): GetByName selects "*", so this depends on the table's column
// order; confirm against the schema.
//
// A row that fails to scan is printed to stdout and skipped, rather than being
// returned as a partially filled entity. DeletedAt keeps its zero value when
// the column is NULL. The caller still owns rows and should check rows.Err()
// afterwards. An empty, non-nil slice is returned when nothing was read.
func (ur userRepository) processRows(rows *sql.Rows) []entity.UserEntity {
	items := []entity.UserEntity{}
	for rows.Next() {
		var (
			id           int64
			username     string
			hash         string
			createdAt    time.Time
			updatedAt    time.Time
			deletedAt    sql.NullTime
			scopes       string
			sessionToken string
		)
		err := rows.Scan(&id, &createdAt, &updatedAt, &deletedAt, &username, &hash, &scopes, &sessionToken)
		if err != nil {
			// Appending here would hand callers a user built from zero or
			// partially scanned values (e.g. an empty Hash), so skip the row.
			fmt.Println(err)
			continue
		}

		item := entity.UserEntity{
			ID:           id,
			UpdatedAt:    updatedAt,
			Username:     username,
			Hash:         hash,
			Scopes:       scopes,
			CreatedAt:    createdAt,
			SessionToken: sessionToken,
		}
		if deletedAt.Valid {
			item.DeletedAt = deletedAt.Time
		}
		items = append(items, item)
	}
	return items
}