package repository import ( "context" "database/sql" "time" "git.jamestombleson.com/jtom38/newsbot-api/internal/domain" "github.com/huandu/go-sqlbuilder" ) type Sources interface { Create(ctx context.Context, source, displayName, url, tags string, enabled bool) (int64, error) GetById(ctx context.Context, id int64) (domain.SourceEntity, error) GetByDisplayName(ctx context.Context, displayName string) (domain.SourceEntity, error) GetBySource(ctx context.Context, source string) (domain.SourceEntity, error) List(ctx context.Context, page, limit int) ([]domain.SourceEntity, error) ListBySource(ctx context.Context, page, limit int, source string) ([]domain.SourceEntity, error) Enable(ctx context.Context, id int64) (int64, error) Disable(ctx context.Context, id int64) (int64, error) SoftDelete(ctx context.Context, id int64) (int64, error) Restore(ctx context.Context, id int64) (int64, error) Delete(ctx context.Context, id int64) (int64, error) } type sourceRepository struct { conn *sql.DB } func NewSourceRepository(conn *sql.DB) sourceRepository { return sourceRepository{ conn: conn, } } func (r sourceRepository) Create(ctx context.Context, source, displayName, url, tags string, enabled bool) (int64, error) { dt := time.Now() queryBuilder := sqlbuilder.NewInsertBuilder() queryBuilder.InsertInto("Sources") queryBuilder.Cols("CreatedAt", "UpdatedAt", "DeletedAt", "DisplayName", "Source", "Url", "Tags", "Enabled") queryBuilder.Values(dt, dt, timeZero, displayName, source, url, tags, enabled) query, args := queryBuilder.Build() _, err := r.conn.ExecContext(ctx, query, args...) if err != nil { return 0, err } return 1, nil } func (r sourceRepository) GetById(ctx context.Context, id int64) (domain.SourceEntity, error) { b := sqlbuilder.NewSelectBuilder() b.Select("*") b.From("Sources").Where( b.Equal("Id", id), ) b.Limit(1) query, args := b.Build() rows, err := r.conn.QueryContext(ctx, query, args...) if err != nil { return domain.SourceEntity{}, err } data, err := r.processRows(rows) if len(data) == 0 { return domain.SourceEntity{}, err } return data[0], nil } func (r sourceRepository) GetByDisplayName(ctx context.Context, displayName string) (domain.SourceEntity, error) { b := sqlbuilder.NewSelectBuilder() b.Select("*") b.From("Sources").Where( b.Equal("DisplayName", displayName), ) b.Limit(1) query, args := b.Build() rows, err := r.conn.QueryContext(ctx, query, args...) if err != nil { return domain.SourceEntity{}, err } data, err := r.processRows(rows) if len(data) == 0 { return domain.SourceEntity{}, err } return data[0], nil } func (r sourceRepository) GetBySource(ctx context.Context, source string) (domain.SourceEntity, error) { b := sqlbuilder.NewSelectBuilder() b.Select("*") b.From("Sources").Where( b.Equal("Source", source), ) b.Limit(1) query, args := b.Build() rows, err := r.conn.QueryContext(ctx, query, args...) if err != nil { return domain.SourceEntity{}, err } data, err := r.processRows(rows) if len(data) == 0 { return domain.SourceEntity{}, err } return data[0], nil } func (r sourceRepository) List(ctx context.Context, page, limit int) ([]domain.SourceEntity, error) { builder := sqlbuilder.NewSelectBuilder() builder.Select("*") builder.From("Sources") builder.Offset(page * limit) builder.Limit(limit) query, args := builder.Build() rows, err := r.conn.QueryContext(ctx, query, args...) if err != nil { return []domain.SourceEntity{}, err } data, err := r.processRows(rows) if len(data) == 0 { return []domain.SourceEntity{}, err } return data, nil } func (r sourceRepository) ListBySource(ctx context.Context, page, limit int, source string) ([]domain.SourceEntity, error) { builder := sqlbuilder.NewSelectBuilder() builder.Select("*") builder.From("Sources") builder.Where( builder.Equal("Source", source), ) builder.Offset(page * limit) builder.Limit(limit) query, args := builder.Build() rows, err := r.conn.QueryContext(ctx, query, args...) if err != nil { return []domain.SourceEntity{}, err } data, err := r.processRows(rows) if len(data) == 0 { return []domain.SourceEntity{}, err } return data, nil } func (r sourceRepository) Enable(ctx context.Context, id int64) (int64, error) { b := sqlbuilder.NewUpdateBuilder() b.Update("Sources") b.Set( b.Assign("Enabled", true), b.Assign("UpdatedAt", time.Now()), ) b.Where( b.Equal("Id", id), ) query, args := b.Build() _, err := r.conn.ExecContext(ctx, query, args...) if err != nil { return 0, err } return 1, nil } func (r sourceRepository) Disable(ctx context.Context, id int64) (int64, error) { b := sqlbuilder.NewUpdateBuilder() b.Update("Sources") b.Set( b.Assign("Enabled", false), b.Assign("UpdatedAt", time.Now()), ) b.Where( b.Equal("Id", id), ) query, args := b.Build() _, err := r.conn.ExecContext(ctx, query, args...) if err != nil { return 0, err } return 1, nil } func (r sourceRepository) SoftDelete(ctx context.Context, id int64) (int64, error) { return softDeleteRow(ctx, r.conn, "Sources", id) } func (r sourceRepository) Restore(ctx context.Context, id int64) (int64, error) { return restoreRow(ctx, r.conn, "Sources", id) } func (r sourceRepository) Delete(ctx context.Context, id int64) (int64, error) { return deleteFromTable(ctx, r.conn, "Sources", id) } func (r sourceRepository) processRows(rows *sql.Rows) ([]domain.SourceEntity, error) { items := []domain.SourceEntity{} for rows.Next() { var id int64 var createdAt time.Time var updatedAt time.Time var deletedAt time.Time var displayName string var source string var enabled bool var url string var tags string err := rows.Scan( &id, &createdAt, &updatedAt, &deletedAt, &displayName, &source, &enabled, &url, &tags, ) if err != nil { return items, err } item := domain.SourceEntity{ ID: id, CreatedAt: createdAt, UpdatedAt: updatedAt, DeletedAt: deletedAt, DisplayName: displayName, Source: source, Enabled: enabled, Url: url, Tags: tags, } items = append(items, item) } return items, nil }