repositories now use context and have interfaces exposed

This commit is contained in:
James Tombleson 2024-04-28 11:39:25 -07:00
parent 7227744621
commit 9586c6a544
6 changed files with 76 additions and 34 deletions

View File

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -12,9 +13,19 @@ import (
const ( const (
ArticleOrderByPublishDateDesc = "pubDate desc" ArticleOrderByPublishDateDesc = "pubDate desc"
ArticleOrderByPublishDatAsc = "pubDate asc" ArticleOrderByPublishDateAsc = "pubDate asc"
) )
type ArticlesRepo interface {
GetById(ctx context.Context, id int64) (domain.ArticleEntity, error)
GetByUrl(ctx context.Context, url string) (domain.ArticleEntity, error)
ListTop(ctx context.Context, limit int) ([]domain.ArticleEntity, error)
ListByPage(ctx context.Context, page, limit int) ([]domain.ArticleEntity, error)
ListByPublishDate(ctx context.Context, page, limit int, orderBy string) ([]domain.ArticleEntity, error)
ListBySource(ctx context.Context, page, limit, sourceId int, orderBy string) ([]domain.ArticleEntity, error)
Create(ctx context.Context, sourceId int64, tags, title, url, thumbnailUrl, description, authorName, authorImageUrl string, pubDate time.Time, isVideo bool) (int64, error)
}
type ArticleRepository struct { type ArticleRepository struct {
conn *sql.DB conn *sql.DB
defaultLimit int defaultLimit int
@ -29,7 +40,7 @@ func NewArticleRepository(conn *sql.DB) ArticleRepository {
} }
} }
func (ar ArticleRepository) GetById(id int64) (domain.ArticleEntity, error) { func (ar ArticleRepository) GetById(ctx context.Context, id int64) (domain.ArticleEntity, error) {
builder := sqlbuilder.NewSelectBuilder() builder := sqlbuilder.NewSelectBuilder()
builder.Select("*") builder.Select("*")
builder.From("articles").Where( builder.From("articles").Where(
@ -38,7 +49,7 @@ func (ar ArticleRepository) GetById(id int64) (domain.ArticleEntity, error) {
builder.Limit(1) builder.Limit(1)
query, args := builder.Build() query, args := builder.Build()
rows, err := ar.conn.Query(query, args...) rows, err := ar.conn.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return domain.ArticleEntity{}, err return domain.ArticleEntity{}, err
} }
@ -51,7 +62,7 @@ func (ar ArticleRepository) GetById(id int64) (domain.ArticleEntity, error) {
return data[0], nil return data[0], nil
} }
func (ar ArticleRepository) GetByUrl(url string) (domain.ArticleEntity, error) { func (ar ArticleRepository) GetByUrl(ctx context.Context, url string) (domain.ArticleEntity, error) {
builder := sqlbuilder.NewSelectBuilder() builder := sqlbuilder.NewSelectBuilder()
builder.Select("*") builder.Select("*")
builder.From("articles").Where( builder.From("articles").Where(
@ -60,7 +71,7 @@ func (ar ArticleRepository) GetByUrl(url string) (domain.ArticleEntity, error) {
builder.Limit(1) builder.Limit(1)
query, args := builder.Build() query, args := builder.Build()
rows, err := ar.conn.Query(query, args...) rows, err := ar.conn.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return domain.ArticleEntity{}, err return domain.ArticleEntity{}, err
} }
@ -73,14 +84,14 @@ func (ar ArticleRepository) GetByUrl(url string) (domain.ArticleEntity, error) {
return data[0], nil return data[0], nil
} }
func (ar ArticleRepository) ListTop(limit int) ([]domain.ArticleEntity, error) { func (ar ArticleRepository) ListTop(ctx context.Context, limit int) ([]domain.ArticleEntity, error) {
builder := sqlbuilder.NewSelectBuilder() builder := sqlbuilder.NewSelectBuilder()
builder.Select("*") builder.Select("*")
builder.From("articles") builder.From("articles")
builder.Limit(limit) builder.Limit(limit)
query, args := builder.Build() query, args := builder.Build()
rows, err := ar.conn.Query(query, args...) rows, err := ar.conn.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return []domain.ArticleEntity{}, err return []domain.ArticleEntity{}, err
} }
@ -93,16 +104,16 @@ func (ar ArticleRepository) ListTop(limit int) ([]domain.ArticleEntity, error) {
return data, nil return data, nil
} }
func (ar ArticleRepository) ListByPage(page, limit int) ([]domain.ArticleEntity, error) { func (ar ArticleRepository) ListByPage(ctx context.Context, page, limit int) ([]domain.ArticleEntity, error) {
builder := sqlbuilder.NewSelectBuilder() builder := sqlbuilder.NewSelectBuilder()
builder.Select("*") builder.Select("*")
builder.From("articles") builder.From("articles")
builder.OrderBy("pubdate desc") builder.OrderBy(ArticleOrderByPublishDateDesc)
builder.Offset(page * limit) builder.Offset(page * limit)
builder.Limit(limit) builder.Limit(limit)
query, args := builder.Build() query, args := builder.Build()
rows, err := ar.conn.Query(query, args...) rows, err := ar.conn.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return []domain.ArticleEntity{}, err return []domain.ArticleEntity{}, err
} }
@ -115,7 +126,7 @@ func (ar ArticleRepository) ListByPage(page, limit int) ([]domain.ArticleEntity,
return data, nil return data, nil
} }
func (ar ArticleRepository) ListByPublishDate(page, limit int, orderBy string) ([]domain.ArticleEntity, error) { func (ar ArticleRepository) ListByPublishDate(ctx context.Context, page, limit int, orderBy string) ([]domain.ArticleEntity, error) {
builder := sqlbuilder.NewSelectBuilder() builder := sqlbuilder.NewSelectBuilder()
builder.Select("*") builder.Select("*")
builder.From("articles") builder.From("articles")
@ -126,7 +137,7 @@ func (ar ArticleRepository) ListByPublishDate(page, limit int, orderBy string) (
builder.Limit(limit) builder.Limit(limit)
query, args := builder.Build() query, args := builder.Build()
rows, err := ar.conn.Query(query, args...) rows, err := ar.conn.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return []domain.ArticleEntity{}, err return []domain.ArticleEntity{}, err
} }
@ -138,7 +149,7 @@ func (ar ArticleRepository) ListByPublishDate(page, limit int, orderBy string) (
return data, nil return data, nil
} }
func (ar ArticleRepository) ListBySource(page, limit int, orderBy string) ([]domain.ArticleEntity, error) { func (ar ArticleRepository) ListBySource(ctx context.Context, page, limit, sourceId int, orderBy string) ([]domain.ArticleEntity, error) {
builder := sqlbuilder.NewSelectBuilder() builder := sqlbuilder.NewSelectBuilder()
builder.Select("*") builder.Select("*")
builder.From("articles") builder.From("articles")
@ -146,11 +157,14 @@ func (ar ArticleRepository) ListBySource(page, limit int, orderBy string) ([]dom
if orderBy != "" { if orderBy != "" {
builder.OrderBy(orderBy) builder.OrderBy(orderBy)
} }
builder.Where(
builder.Equal("SourceId", sourceId),
)
builder.Offset(50) builder.Offset(50)
builder.Limit(page * limit) builder.Limit(page * limit)
query, args := builder.Build() query, args := builder.Build()
rows, err := ar.conn.Query(query, args...) rows, err := ar.conn.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return []domain.ArticleEntity{}, err return []domain.ArticleEntity{}, err
} }
@ -162,7 +176,7 @@ func (ar ArticleRepository) ListBySource(page, limit int, orderBy string) ([]dom
return data, nil return data, nil
} }
func (ar ArticleRepository) Create(sourceId int64, tags, title, url, thumbnailUrl, description, authorName, authorImageUrl string, pubDate time.Time, isVideo bool) (int64, error) { func (ar ArticleRepository) Create(ctx context.Context, sourceId int64, tags, title, url, thumbnailUrl, description, authorName, authorImageUrl string, pubDate time.Time, isVideo bool) (int64, error) {
dt := time.Now() dt := time.Now()
queryBuilder := sqlbuilder.NewInsertBuilder() queryBuilder := sqlbuilder.NewInsertBuilder()
queryBuilder.InsertInto("articles") queryBuilder.InsertInto("articles")
@ -170,7 +184,7 @@ func (ar ArticleRepository) Create(sourceId int64, tags, title, url, thumbnailUr
queryBuilder.Values(dt, dt, timeZero, sourceId, tags, title, url, pubDate, isVideo, thumbnailUrl, description, authorName, authorImageUrl) queryBuilder.Values(dt, dt, timeZero, sourceId, tags, title, url, pubDate, isVideo, thumbnailUrl, description, authorName, authorImageUrl)
query, args := queryBuilder.Build() query, args := queryBuilder.Build()
_, err := ar.conn.Exec(query, args...) _, err := ar.conn.ExecContext(ctx, query, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -1,6 +1,7 @@
package repository_test package repository_test
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -21,7 +22,7 @@ func TestCreateArticle(t *testing.T) {
defer db.Close() defer db.Close()
r := repository.NewArticleRepository(db) r := repository.NewArticleRepository(db)
created, err := r.Create(1, "", "unit test", articleFakeDotCom, "", "testing", "", "", time.Now(), false) created, err := r.Create(context.Background(), 1, "", "unit test", articleFakeDotCom, "", "testing", "", "", time.Now(), false)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()
@ -48,7 +49,7 @@ func TestArticleByUrl(t *testing.T) {
t.FailNow() t.FailNow()
} }
article, err := r.GetByUrl(articleFakeDotCom) article, err := r.GetByUrl(context.Background(), articleFakeDotCom)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()
@ -73,7 +74,7 @@ func TestPullingMultipleArticlesWithLimit(t *testing.T) {
insertFakeArticles(r, "u3", 0) insertFakeArticles(r, "u3", 0)
insertFakeArticles(r, "u4", 0) insertFakeArticles(r, "u4", 0)
items, err := r.ListTop(3) items, err := r.ListTop(context.Background(), 3)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()
@ -98,7 +99,7 @@ func TestPullingMultipleArticlesWithPaging(t *testing.T) {
insertFakeArticles(r, "u3", 0) insertFakeArticles(r, "u3", 0)
insertFakeArticles(r, "u4", 0) insertFakeArticles(r, "u4", 0)
items, err := r.ListByPage(2, 1) items, err := r.ListByPage(context.Background(), 2, 1)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()
@ -125,7 +126,7 @@ func TestPullingByPublishDate(t *testing.T) {
insertFakeArticles(r, "u1", -1) insertFakeArticles(r, "u1", -1)
insertFakeArticles(r, "u1", -2) insertFakeArticles(r, "u1", -2)
items, err := r.ListByPublishDate(0, 2, repository.ArticleOrderByPublishDateDesc) items, err := r.ListByPublishDate(context.Background(), 0, 2, repository.ArticleOrderByPublishDateDesc)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()
@ -147,7 +148,7 @@ func TestPullingByPublishDate(t *testing.T) {
func insertFakeArticles(r repository.ArticleRepository, title string, daysOld int) error { func insertFakeArticles(r repository.ArticleRepository, title string, daysOld int) error {
pubDate := time.Now().AddDate(0,0, daysOld) pubDate := time.Now().AddDate(0,0, daysOld)
_, err := r.Create(1, "", title, articleFakeDotCom, "", "testing", "", "", pubDate, false) _, err := r.Create(context.Background(), 1, "", title, articleFakeDotCom, "", "testing", "", "", pubDate, false)
if err != nil { if err != nil {
return err return err
} }

View File

@ -9,6 +9,19 @@ import (
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
) )
type DiscordWebHookRepo interface{
Create(ctx context.Context, url, server, channel string, enabled bool) (int64, error)
Enable(ctx context.Context, id int64) (int64, error)
Disable(ctx context.Context, id int64) (int64, error)
SoftDelete(ctx context.Context, id int64) (int64, error)
Restore(ctx context.Context, id int64) (int64, error)
Delete(ctx context.Context, id int64) (int64, error)
GetById(ctx context.Context, id int64) (domain.DiscordWebHookEntity, error)
GetByUrl(ctx context.Context, url string) (domain.DiscordWebHookEntity, error)
ListByServerName(ctx context.Context, name string) ([]domain.DiscordWebHookEntity, error)
ListByServerAndChannel(ctx context.Context, server, channel string) ([]domain.DiscordWebHookEntity, error)
}
type discordWebHookRepository struct { type discordWebHookRepository struct {
conn *sql.DB conn *sql.DB
} }

View File

@ -14,7 +14,7 @@ const (
refreshTokenTableName = "RefreshTokens" refreshTokenTableName = "RefreshTokens"
) )
type RefreshTokenTable interface { type RefreshToken interface {
Create(username string, token string) (int64, error) Create(username string, token string) (int64, error)
GetByUsername(name string) (domain.RefreshTokenEntity, error) GetByUsername(name string) (domain.RefreshTokenEntity, error)
DeleteById(id int64) (int64, error) DeleteById(id int64) (int64, error)

View File

@ -9,6 +9,20 @@ import (
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
) )
type Sources interface {
Create(ctx context.Context, source, displayName, url, tags string, enabled bool) (int64, error)
GetById(ctx context.Context, id int64) (domain.SourceEntity, error)
GetByDisplayName(ctx context.Context, displayName string) (domain.SourceEntity, error)
GetBySource(ctx context.Context, source string) (domain.SourceEntity, error)
List(ctx context.Context, page, limit int) ([]domain.SourceEntity, error)
ListBySource(ctx context.Context, page, limit int, source string) ([]domain.SourceEntity, error)
Enable(ctx context.Context, id int64) (int64, error)
Disable(ctx context.Context, id int64) (int64, error)
SoftDelete(ctx context.Context, id int64) (int64, error)
Restore(ctx context.Context, id int64) (int64, error)
Delete(ctx context.Context, id int64) (int64, error)
}
type sourceRepository struct { type sourceRepository struct {
conn *sql.DB conn *sql.DB
} }

View File

@ -17,7 +17,7 @@ const (
ErrUserNotFound string = "requested user was not found" ErrUserNotFound string = "requested user was not found"
) )
type IUserTable interface { type Users interface {
GetByName(name string) (domain.UserEntity, error) GetByName(name string) (domain.UserEntity, error)
Create(name, password, scope string) (int64, error) Create(name, password, scope string) (int64, error)
Update(id int, entity domain.UserEntity) error Update(id int, entity domain.UserEntity) error
@ -27,17 +27,17 @@ type IUserTable interface {
} }
// Creates a new instance of UserRepository with the bound sql // Creates a new instance of UserRepository with the bound sql
func NewUserRepository(conn *sql.DB) UserRepository { func NewUserRepository(conn *sql.DB) userRepository {
return UserRepository{ return userRepository{
connection: conn, connection: conn,
} }
} }
type UserRepository struct { type userRepository struct {
connection *sql.DB connection *sql.DB
} }
func (ur UserRepository) GetByName(name string) (domain.UserEntity, error) { func (ur userRepository) GetByName(name string) (domain.UserEntity, error) {
builder := sqlbuilder.NewSelectBuilder() builder := sqlbuilder.NewSelectBuilder()
builder.Select("*").From("users").Where( builder.Select("*").From("users").Where(
builder.E("Name", name), builder.E("Name", name),
@ -57,7 +57,7 @@ func (ur UserRepository) GetByName(name string) (domain.UserEntity, error) {
return data[0], nil return data[0], nil
} }
func (ur UserRepository) Create(name, password, scope string) (int64, error) { func (ur userRepository) Create(name, password, scope string) (int64, error) {
passwordBytes := []byte(password) passwordBytes := []byte(password)
hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
if err != nil { if err != nil {
@ -79,11 +79,11 @@ func (ur UserRepository) Create(name, password, scope string) (int64, error) {
return 1, nil return 1, nil
} }
func (ur UserRepository) Update(id int, entity domain.UserEntity) error { func (ur userRepository) Update(id int, entity domain.UserEntity) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (ur UserRepository) UpdatePassword(name, password string) error { func (ur userRepository) UpdatePassword(name, password string) error {
_, err := ur.GetByName(name) _, err := ur.GetByName(name)
if err != nil { if err != nil {
return nil return nil
@ -97,7 +97,7 @@ func (ur UserRepository) UpdatePassword(name, password string) error {
// If the hash matches what we have in the database, an error will not be returned. // If the hash matches what we have in the database, an error will not be returned.
// If the user does not exist or the hash does not match, an error will be returned // If the user does not exist or the hash does not match, an error will be returned
func (ur UserRepository) CheckUserHash(name, password string) error { func (ur userRepository) CheckUserHash(name, password string) error {
record, err := ur.GetByName(name) record, err := ur.GetByName(name)
if err != nil { if err != nil {
return err return err
@ -111,7 +111,7 @@ func (ur UserRepository) CheckUserHash(name, password string) error {
return nil return nil
} }
func (ur UserRepository) UpdateScopes(name, scope string) error { func (ur userRepository) UpdateScopes(name, scope string) error {
builder := sqlbuilder.NewUpdateBuilder() builder := sqlbuilder.NewUpdateBuilder()
builder.Update("users") builder.Update("users")
builder.Set( builder.Set(
@ -129,7 +129,7 @@ func (ur UserRepository) UpdateScopes(name, scope string) error {
return nil return nil
} }
func (ur UserRepository) processRows(rows *sql.Rows) []domain.UserEntity { func (ur userRepository) processRows(rows *sql.Rows) []domain.UserEntity {
items := []domain.UserEntity{} items := []domain.UserEntity{}
for rows.Next() { for rows.Next() {