Compare commits

..

4 Commits

18 changed files with 581 additions and 119 deletions

2
go.mod
View File

@ -6,9 +6,11 @@ require (
github.com/PuerkitoBio/goquery v1.8.0 github.com/PuerkitoBio/goquery v1.8.0
github.com/glebarez/go-sqlite v1.22.0 github.com/glebarez/go-sqlite v1.22.0
github.com/go-rod/rod v0.107.1 github.com/go-rod/rod v0.107.1
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/huandu/go-sqlbuilder v1.27.1 github.com/huandu/go-sqlbuilder v1.27.1
github.com/joho/godotenv v1.4.0 github.com/joho/godotenv v1.4.0
github.com/labstack/echo-jwt/v4 v4.2.0
github.com/labstack/echo/v4 v4.12.0 github.com/labstack/echo/v4 v4.12.0
github.com/mmcdole/gofeed v1.1.3 github.com/mmcdole/gofeed v1.1.3
github.com/nicklaw5/helix/v2 v2.4.0 github.com/nicklaw5/helix/v2 v2.4.0

4
go.sum
View File

@ -35,6 +35,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
@ -60,6 +62,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/labstack/echo-jwt/v4 v4.2.0 h1:odSISV9JgcSCuhgQSV/6Io3i7nUmfM/QkBeR5GVJj5c=
github.com/labstack/echo-jwt/v4 v4.2.0/go.mod h1:MA2RqdXdEn4/uEglx0HcUOgQSyBaTh5JcaHIan3biwU=
github.com/labstack/echo/v4 v4.12.0 h1:IKpw49IMryVB2p1a4dzwlhP1O2Tf2E0Ir/450lH+kI0= github.com/labstack/echo/v4 v4.12.0 h1:IKpw49IMryVB2p1a4dzwlhP1O2Tf2E0Ir/450lH+kI0=
github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM= github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM=
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=

View File

@ -72,15 +72,36 @@ CREATE Table Sources (
Tags TEXT NOT NULL Tags TEXT NOT NULL
); );
/*
CREATE TABLE Subscriptions ( CREATE TABLE Subscriptions (
ID INTEGER PRIMARY KEY AUTOINCREMENT, ID INTEGER PRIMARY KEY AUTOINCREMENT,
CreatedAt DATETIME NOT NULL, CreatedAt DATETIME NOT NULL,
UpdatedAt DATETIME NOT NULL, UpdatedAt DATETIME NOT NULL,
DeletedAt DATETIME, DeletedAt DATETIME NOT NULL,
DiscordWebHookID NUMBER NOT NULL, DiscordWebHookID NUMBER NOT NULL,
SourceID NUMBER NOT NULL, SourceID NUMBER NOT NULL,
UserID NUMBER NOT NULL UserID NUMBER NOT NULL
); );
*/
CREATE TABLE UserSourceSubscriptions (
ID INTEGER PRIMARY KEY AUTOINCREMENT,
CreatedAt DATETIME NOT NULL,
UpdatedAt DATETIME NOT NULL,
DeletedAt DATETIME NOT NULL,
UserID NUMBER NOT NULL,
SourceID NUMBER NOT NULL
);
CREATE TABLE AlertDiscord (
ID INTEGER PRIMARY KEY AUTOINCREMENT,
CreatedAt DATETIME NOT NULL,
UpdatedAt DATETIME NOT NULL,
DeletedAt DATETIME NOT NULL,
UserID NUMBER NOT NULL,
SourceID NUMBER NOT NULL,
DiscordWebHookID NUMBER NOT NULL
);
CREATE TABLE Users ( CREATE TABLE Users (
ID INTEGER PRIMARY KEY AUTOINCREMENT, ID INTEGER PRIMARY KEY AUTOINCREMENT,

View File

@ -4,6 +4,18 @@ import (
"time" "time"
) )
// This links a source to a discord webhook.
// It is owned by a user so they can remove the link
type AlertDiscordEntity struct {
ID int64
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
UserID int64
SourceID int64
DiscordWebHookId int64
}
type ArticleEntity struct { type ArticleEntity struct {
ID int64 ID int64
CreatedAt time.Time CreatedAt time.Time
@ -35,10 +47,11 @@ type DiscordWebHookEntity struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt time.Time DeletedAt time.Time
Url string UserID int64
Server string Url string
Channel string Server string
Enabled bool Channel string
Enabled bool
} }
type IconEntity struct { type IconEntity struct {
@ -61,38 +74,50 @@ type SettingEntity struct {
} }
type SourceEntity struct { type SourceEntity struct {
ID int64 ID int64
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt time.Time DeletedAt time.Time
// Who will collect from it. Used // Who will collect from it. Used
// domain.SourceCollector... // domain.SourceCollector...
Source string Source string
// Human Readable value to state what is getting collected // Human Readable value to state what is getting collected
DisplayName string DisplayName string
// Tells the parser where to look for data // Tells the parser where to look for data
Url string Url string
// Static tags for this defined record // Static tags for this defined record
Tags string Tags string
// If the record is disabled, then it will be skipped on processing // If the record is disabled, then it will be skipped on processing
Enabled bool Enabled bool
} }
type SubscriptionEntity struct { //type SubscriptionEntity struct {
ID int64 // ID int64
CreatedAt time.Time // CreatedAt time.Time
UpdatedAt time.Time // UpdatedAt time.Time
DeletedAt time.Time // DeletedAt time.Time
SourceID int64 // UserID int64
SourceType string // SourceID int64
SourceName string // //SourceType string
DiscordID int64 // //SourceName string
DiscordName string // DiscordID int64
// //DiscordName string
//}
// This defines what sources a user wants to follow.
// These will show up for the user as a front page
type UserSourceSubscriptionEntity struct {
ID int64
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
UserID int64
SourceID int64
} }
type UserEntity struct { type UserEntity struct {

View File

@ -0,0 +1,6 @@
package domain
const (
ScopeAll = "newsbot:all"
ScopeRead = "newsbot:read"
)

View File

@ -111,6 +111,13 @@ func (s *Handler) GetDiscordWebHooksByServerAndChannel(c echo.Context) error {
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) NewDiscordWebHook(c echo.Context) error { func (s *Handler) NewDiscordWebHook(c echo.Context) error {
token, err := s.getJwtToken(c)
if err != nil {
return c.JSON(http.StatusUnauthorized, domain.BaseResponse{
Message: ErrJwtMissing,
})
}
_url := c.QueryParam("url") _url := c.QueryParam("url")
_server := c.QueryParam("server") _server := c.QueryParam("server")
_channel := c.QueryParam("channel") _channel := c.QueryParam("channel")
@ -136,7 +143,12 @@ func (s *Handler) NewDiscordWebHook(c echo.Context) error {
}) })
} }
rows, err := s.repo.DiscordWebHooks.Create(c.Request().Context(), _url, _server, _channel, true) user, err := s.repo.Users.GetByName(token.UserName)
if err != nil {
s.WriteMessage(c, ErrUserUnknown, http.StatusBadRequest)
}
rows, err := s.repo.DiscordWebHooks.Create(c.Request().Context(), user.ID, _url, _server, _channel, true)
if err != nil { if err != nil {
s.WriteError(c, err, http.StatusInternalServerError) s.WriteError(c, err, http.StatusInternalServerError)
} }

View File

@ -3,7 +3,10 @@ package v1
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"github.com/golang-jwt/jwt/v5"
echojwt "github.com/labstack/echo-jwt/v4"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware" "github.com/labstack/echo/v4/middleware"
swagger "github.com/swaggo/echo-swagger" swagger "github.com/swaggo/echo-swagger"
@ -17,7 +20,6 @@ import (
type Handler struct { type Handler struct {
Router *echo.Echo Router *echo.Echo
Db *database.Queries Db *database.Queries
//dto *dto.DtoClient
config services.Configs config services.Configs
repo services.RepositoryService repo services.RepositoryService
} }
@ -28,6 +30,7 @@ const (
ErrUnableToParseId = "Unable to parse the requested ID" ErrUnableToParseId = "Unable to parse the requested ID"
ErrRecordMissing = "The requested record was not found" ErrRecordMissing = "The requested record was not found"
ErrFailedToCreateRecord = "The record was not created due to a database problem" ErrFailedToCreateRecord = "The record was not created due to a database problem"
ErrUserUnknown = "User is unknown"
ResponseMessageSuccess = "Success" ResponseMessageSuccess = "Success"
) )
@ -47,6 +50,13 @@ func NewServer(ctx context.Context, configs services.Configs, conn *sql.DB) *Han
repo: services.NewRepositoryService(conn), repo: services.NewRepositoryService(conn),
} }
jwtConfig := echojwt.Config{
NewClaimsFunc: func(c echo.Context) jwt.Claims {
return new(JwtToken)
},
SigningKey: []byte(configs.JwtSecret),
}
router := echo.New() router := echo.New()
router.Pre(middleware.RemoveTrailingSlash()) router.Pre(middleware.RemoveTrailingSlash())
router.Pre(middleware.Logger()) router.Pre(middleware.Logger())
@ -55,6 +65,7 @@ func NewServer(ctx context.Context, configs services.Configs, conn *sql.DB) *Han
v1 := router.Group("/api/v1") v1 := router.Group("/api/v1")
articles := v1.Group("/articles") articles := v1.Group("/articles")
articles.Use(echojwt.WithConfig(jwtConfig))
articles.GET("", s.listArticles) articles.GET("", s.listArticles)
articles.GET(":id", s.getArticle) articles.GET(":id", s.getArticle)
articles.GET(":id/details", s.getArticleDetails) articles.GET(":id/details", s.getArticleDetails)
@ -76,6 +87,7 @@ func NewServer(ctx context.Context, configs services.Configs, conn *sql.DB) *Han
//settings.GET("/", s.getSettings) //settings.GET("/", s.getSettings)
sources := v1.Group("/sources") sources := v1.Group("/sources")
sources.Use(echojwt.WithConfig(jwtConfig))
sources.GET("", s.listSources) sources.GET("", s.listSources)
sources.GET("/by/source", s.listSourcesBySource) sources.GET("/by/source", s.listSourcesBySource)
sources.GET("/by/sourceAndName", s.GetSourceBySourceAndName) sources.GET("/by/sourceAndName", s.GetSourceBySourceAndName)
@ -89,6 +101,7 @@ func NewServer(ctx context.Context, configs services.Configs, conn *sql.DB) *Han
sources.POST("/:ID/enable", s.enableSource) sources.POST("/:ID/enable", s.enableSource)
subs := v1.Group("/subscriptions") subs := v1.Group("/subscriptions")
subs.Use(echojwt.WithConfig(jwtConfig))
subs.GET("/", s.ListSubscriptions) subs.GET("/", s.ListSubscriptions)
subs.GET("/details", s.ListSubscriptionDetails) subs.GET("/details", s.ListSubscriptionDetails)
subs.GET("/by/discordId", s.GetSubscriptionsByDiscordId) subs.GET("/by/discordId", s.GetSubscriptionsByDiscordId)
@ -120,3 +133,19 @@ func (s *Handler) WriteMessage(c echo.Context, msg string, HttpStatusCode int) e
Message: msg, Message: msg,
}) })
} }
func (h *Handler) getJwtToken(c echo.Context) (JwtToken, error) {
// Make sure that the request came with a jwtToken
token, ok := c.Get("user").(*jwt.Token)
if !ok {
return JwtToken{}, errors.New(ErrJwtMissing)
}
// Generate the claims from the token
claims, ok := token.Claims.(*JwtToken)
if !ok {
return JwtToken{}, errors.New(ErrJwtClaimsMissing)
}
return *claims, nil
}

View File

@ -0,0 +1,99 @@
package v1
import (
"errors"
"strings"
"time"
"git.jamestombleson.com/jtom38/newsbot-api/internal/domain"
"github.com/golang-jwt/jwt/v5"
)
const (
ErrJwtMissing = "auth token is missing"
ErrJwtClaimsMissing = "claims missing on token"
ErrJwtExpired = "auth token has expired"
ErrJwtScopeMissing = "required scope is missing"
)
type JwtToken struct {
Exp time.Time `json:"exp"`
Iss string `json:"iss"`
Authorized bool `json:"authorized"`
UserName string `json:"username"`
Scopes []string `json:"scopes"`
jwt.RegisteredClaims
}
func (j JwtToken) IsValid(scope string) error {
err := j.hasExpired()
if err != nil {
return err
}
err = j.hasScope(scope)
if err != nil {
return err
}
return nil
}
func (j JwtToken) GetUsername() string {
return j.UserName
}
func (j JwtToken) hasExpired() error {
// Check to see if the token has expired
hasExpired := j.Exp.Compare(time.Now())
if hasExpired == -1 {
return errors.New(ErrJwtExpired)
}
return nil
}
func (j JwtToken) hasScope(scope string) error {
// they have the scope to access everything, so let them pass.
if strings.Contains(domain.ScopeAll, scope) {
return nil
}
for _, s := range j.Scopes {
if strings.Contains(s, scope) {
return nil
}
}
return errors.New(ErrJwtScopeMissing)
}
func (h *Handler) generateJwt(username, issuer string) (string, error) {
return h.generateJwtWithExp(username, issuer, time.Now().Add(10*time.Minute))
}
func (h *Handler) generateJwtWithExp(username, issuer string, expiresAt time.Time) (string, error) {
secret := []byte(h.config.JwtSecret)
// Anyone who wants to decrypt the key needs to use the same method
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["exp"] = expiresAt
claims["authorized"] = true
claims["username"] = username
claims["iss"] = issuer
var scopes []string
if username == "admin" {
scopes = append(scopes, domain.ScopeAll)
claims["scopes"] = scopes
} else {
scopes = append(scopes, domain.ScopeRead)
claims["scopes"] = scopes
}
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}

View File

@ -1,39 +0,0 @@
package v1
import (
"encoding/json"
"net/http"
"git.jamestombleson.com/jtom38/newsbot-api/internal/domain"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
)
// GetSettings
// @Summary Returns a object based on the Key that was given.
// @Param key path string true "Settings Key value"
// @Produce application/json
// @Tags Settings
// @Router /v1/settings/{key} [get]
func (s *Handler) getSettings(c echo.Context) error {
id := c.Param("ID")
uuid, err := uuid.Parse(id)
if err != nil {
return c.JSON(http.StatusBadRequest, domain.BaseResponse{
Message: err.Error(),
})
}
res, err := s.Db.GetSourceByID(c.Request().Context(), uuid)
if err != nil {
return c.JSON(http.StatusInternalServerError, err.Error())
}
bResult, err := json.Marshal(res)
if err != nil {
return c.JSON(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, bResult)
}

View File

@ -10,22 +10,11 @@ import (
"git.jamestombleson.com/jtom38/newsbot-api/internal/database" "git.jamestombleson.com/jtom38/newsbot-api/internal/database"
"git.jamestombleson.com/jtom38/newsbot-api/internal/domain" "git.jamestombleson.com/jtom38/newsbot-api/internal/domain"
"git.jamestombleson.com/jtom38/newsbot-api/internal/domain/models"
"git.jamestombleson.com/jtom38/newsbot-api/internal/services" "git.jamestombleson.com/jtom38/newsbot-api/internal/services"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ListSources struct {
ApiStatusModel
Payload []models.SourceDto `json:"payload"`
}
type GetSource struct {
ApiStatusModel
Payload models.SourceDto `json:"payload"`
}
// ListSources // ListSources
// @Summary Lists the top 50 records // @Summary Lists the top 50 records
// @Param page query string false "page number" // @Param page query string false "page number"
@ -35,7 +24,7 @@ type GetSource struct {
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse "Unable to reach SQL or Data problems" // @Failure 400 {object} domain.BaseResponse "Unable to reach SQL or Data problems"
func (s *Handler) listSources(c echo.Context) error { func (s *Handler) listSources(c echo.Context) error {
resp := domain.SourcesResponse { resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
Message: ResponseMessageSuccess, Message: ResponseMessageSuccess,
}, },
@ -64,8 +53,8 @@ func (s *Handler) listSources(c echo.Context) error {
// @Tags Source // @Tags Source
// @Router /v1/sources/by/source [get] // @Router /v1/sources/by/source [get]
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) listSourcesBySource(c echo.Context) error { func (s *Handler) listSourcesBySource(c echo.Context) error {
resp := domain.SourcesResponse{ resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
@ -102,8 +91,8 @@ func (s *Handler) listSourcesBySource(c echo.Context) error {
// @Tags Source // @Tags Source
// @Router /v1/sources/{id} [get] // @Router /v1/sources/{id} [get]
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) getSource(c echo.Context) error { func (s *Handler) getSource(c echo.Context) error {
resp := domain.SourcesResponse{ resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
@ -137,8 +126,8 @@ func (s *Handler) getSource(c echo.Context) error {
// @Tags Source // @Tags Source
// @Router /v1/sources/by/sourceAndName [get] // @Router /v1/sources/by/sourceAndName [get]
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) GetSourceBySourceAndName(c echo.Context) error { func (s *Handler) GetSourceBySourceAndName(c echo.Context) error {
resp := domain.SourcesResponse{ resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
@ -172,8 +161,8 @@ func (s *Handler) GetSourceBySourceAndName(c echo.Context) error {
// @Tags Source // @Tags Source
// @Router /v1/sources/new/reddit [post] // @Router /v1/sources/new/reddit [post]
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) newRedditSource(c echo.Context) error { func (s *Handler) newRedditSource(c echo.Context) error {
resp := domain.SourcesResponse{ resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
@ -339,8 +328,8 @@ func (s *Handler) newTwitchSource(c echo.Context) error {
// @Tags Source // @Tags Source
// @Router /v1/sources/new/rss [post] // @Router /v1/sources/new/rss [post]
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) newRssSource(c echo.Context) error { func (s *Handler) newRssSource(c echo.Context) error {
resp := domain.SourcesResponse{ resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
@ -434,10 +423,10 @@ func (s *Handler) deleteSources(c echo.Context) error {
// @Tags Source // @Tags Source
// @Router /v1/sources/{id}/disable [post] // @Router /v1/sources/{id}/disable [post]
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) disableSource(c echo.Context) error { func (s *Handler) disableSource(c echo.Context) error {
resp := domain.SourcesResponse { resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
Message: ResponseMessageSuccess, Message: ResponseMessageSuccess,
}, },
@ -476,10 +465,10 @@ func (s *Handler) disableSource(c echo.Context) error {
// @Tags Source // @Tags Source
// @Router /v1/sources/{id}/enable [post] // @Router /v1/sources/{id}/enable [post]
// @Success 200 {object} domain.SourcesResponse "ok" // @Success 200 {object} domain.SourcesResponse "ok"
// @Failure 400 {object} domain.BaseResponse // @Failure 400 {object} domain.BaseResponse
// @Failure 500 {object} domain.BaseResponse // @Failure 500 {object} domain.BaseResponse
func (s *Handler) enableSource(c echo.Context) error { func (s *Handler) enableSource(c echo.Context) error {
resp := domain.SourcesResponse { resp := domain.SourcesResponse{
BaseResponse: domain.BaseResponse{ BaseResponse: domain.BaseResponse{
Message: ResponseMessageSuccess, Message: ResponseMessageSuccess,
}, },

View File

@ -0,0 +1,122 @@
package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"git.jamestombleson.com/jtom38/newsbot-api/internal/domain"
"github.com/huandu/go-sqlbuilder"
)
type AlertDiscordRepo interface {
Create(ctx context.Context, userId, sourceId, webhookId int64) (int64, error)
SoftDelete(ctx context.Context, id int64) (int64, error)
Restore(ctx context.Context, id int64) (int64, error)
Delete(ctx context.Context, id int64) (int64, error)
ListByUser(ctx context.Context, page, limit int, userId int64) ([]domain.AlertDiscordEntity, error)
}
type alertDiscordRepository struct {
conn *sql.DB
defaultLimit int
defaultOffset int
}
func NewAlertDiscordRepository(conn *sql.DB) alertDiscordRepository {
return alertDiscordRepository{
conn: conn,
defaultLimit: 50,
defaultOffset: 50,
}
}
func (r alertDiscordRepository) Create(ctx context.Context, userId, sourceId, webhookId int64) (int64, error) {
dt := time.Now()
queryBuilder := sqlbuilder.NewInsertBuilder()
queryBuilder.InsertInto("AlertDiscord")
queryBuilder.Cols("UpdatedAt", "CreatedAt", "DeletedAt", "UserID", "SourceID", "DiscordWebHookID")
queryBuilder.Values(dt, dt, timeZero, userId, sourceId, webhookId)
query, args := queryBuilder.Build()
_, err := r.conn.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return 1, nil
}
func (r alertDiscordRepository) SoftDelete(ctx context.Context, id int64) (int64, error) {
return softDeleteRow(ctx, r.conn, "AlertDiscord", id)
}
func (r alertDiscordRepository) Restore(ctx context.Context, id int64) (int64, error) {
return restoreRow(ctx, r.conn, "AlertDiscord", id)
}
func (r alertDiscordRepository) Delete(ctx context.Context, id int64) (int64, error) {
return deleteFromTable(ctx, r.conn, "AlertDiscord", id)
}
func (r alertDiscordRepository) ListByUser(ctx context.Context, page, limit int, userId int64) ([]domain.AlertDiscordEntity, error) {
builder := sqlbuilder.NewSelectBuilder()
builder.Select("*")
builder.From("AlertDiscord")
builder.Where(
builder.Equal("UserID", userId),
)
builder.Offset(page * limit)
builder.Limit(limit)
query, args := builder.Build()
rows, err := r.conn.QueryContext(ctx, query, args...)
if err != nil {
return []domain.AlertDiscordEntity{}, err
}
data := r.processRows(rows)
if len(data) == 0 {
return []domain.AlertDiscordEntity{}, errors.New(ErrUserNotFound)
}
return data, nil
}
func (ur alertDiscordRepository) processRows(rows *sql.Rows) []domain.AlertDiscordEntity {
items := []domain.AlertDiscordEntity{}
for rows.Next() {
var id int64
var createdAt time.Time
var updatedAt time.Time
var deletedAt time.Time
var userId int64
var sourceId int64
var webhookId int64
err := rows.Scan(
&id, &createdAt, &updatedAt, &deletedAt,
&userId, &sourceId, &webhookId,
)
if err != nil {
fmt.Println(err)
}
item := domain.AlertDiscordEntity{
ID: id,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
DeletedAt: deletedAt,
UserID: userId,
SourceID: sourceId,
DiscordWebHookId: webhookId,
}
items = append(items, item)
}
return items
}

View File

@ -0,0 +1,63 @@
package repository_test
import (
"context"
"testing"
"time"
"git.jamestombleson.com/jtom38/newsbot-api/internal/domain"
"git.jamestombleson.com/jtom38/newsbot-api/internal/repository"
)
func TestAlertDiscordCreate(t *testing.T) {
t.Log(time.Time{})
db, err := setupInMemoryDb()
if err != nil {
t.Log(err)
t.FailNow()
}
defer db.Close()
r := repository.NewAlertDiscordRepository(db)
created, err := r.Create(context.Background(), 1, 1, 1)
if err != nil {
t.Log(err)
t.FailNow()
}
if created != 1 {
t.Log("failed to create the record")
t.FailNow()
}
}
func TestAlertDiscordCreateAndValidate(t *testing.T) {
t.Log(time.Time{})
db, err := setupInMemoryDb()
if err != nil {
t.Log(err)
t.FailNow()
}
defer db.Close()
source := repository.NewSourceRepository(db)
source.Create(context.Background(), domain.SourceCollectorRss, "Unit Testing", "www.fake.com", "testing,units", true)
sourceRecord, _ := source.GetBySourceAndName(context.Background(), domain.SourceCollectorRss, "Unit Testing")
webhookRepo := repository.NewDiscordWebHookRepository(db)
webhookRepo.Create(context.Background(), 999, "discord.com", "Unit Testing", "memes", true)
webhook, _ := webhookRepo.GetByUrl(context.Background(), "discord.com")
r := repository.NewAlertDiscordRepository(db)
r.Create(context.Background(), 999, sourceRecord.ID, webhook.ID)
alert, err := r.ListByUser(context.Background(), 0, 10, 999)
if err != nil {
t.Error(err)
t.FailNow()
}
if len(alert) != 1 {
t.Error("got the incorrect number of rows back")
t.FailNow()
}
}

View File

@ -9,8 +9,8 @@ import (
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
) )
type DiscordWebHookRepo interface{ type DiscordWebHookRepo interface {
Create(ctx context.Context, url, server, channel string, enabled bool) (int64, error) Create(ctx context.Context, userId int64, url, server, channel string, enabled bool) (int64, error)
Enable(ctx context.Context, id int64) (int64, error) Enable(ctx context.Context, id int64) (int64, error)
Disable(ctx context.Context, id int64) (int64, error) Disable(ctx context.Context, id int64) (int64, error)
SoftDelete(ctx context.Context, id int64) (int64, error) SoftDelete(ctx context.Context, id int64) (int64, error)
@ -32,12 +32,12 @@ func NewDiscordWebHookRepository(conn *sql.DB) discordWebHookRepository {
} }
} }
func (r discordWebHookRepository) Create(ctx context.Context, url, server, channel string, enabled bool) (int64, error) { func (r discordWebHookRepository) Create(ctx context.Context, userId int64, url, server, channel string, enabled bool) (int64, error) {
dt := time.Now() dt := time.Now()
queryBuilder := sqlbuilder.NewInsertBuilder() queryBuilder := sqlbuilder.NewInsertBuilder()
queryBuilder.InsertInto("DiscordWebHooks") queryBuilder.InsertInto("DiscordWebHooks")
queryBuilder.Cols("UpdatedAt", "CreatedAt", "DeletedAt", "Url", "Server", "Channel", "Enabled") queryBuilder.Cols("UpdatedAt", "CreatedAt", "DeletedAt", "UserID", "Url", "Server", "Channel", "Enabled")
queryBuilder.Values(dt, dt, timeZero, url, server, channel, enabled) queryBuilder.Values(dt, dt, timeZero, userId, url, server, channel, enabled)
query, args := queryBuilder.Build() query, args := queryBuilder.Build()
_, err := r.conn.ExecContext(ctx, query, args...) _, err := r.conn.ExecContext(ctx, query, args...)
@ -195,13 +195,14 @@ func (r discordWebHookRepository) processRows(rows *sql.Rows) ([]domain.DiscordW
var createdAt time.Time var createdAt time.Time
var updatedAt time.Time var updatedAt time.Time
var deletedAt time.Time var deletedAt time.Time
var userId int64
var url string var url string
var server string var server string
var channel string var channel string
var enabled bool var enabled bool
err := rows.Scan( err := rows.Scan(
&id, &createdAt, &updatedAt, &id, &createdAt, &updatedAt,
&deletedAt, &url, &server, &deletedAt, &userId, &url, &server,
&channel, &enabled, &channel, &enabled,
) )
if err != nil { if err != nil {
@ -213,6 +214,7 @@ func (r discordWebHookRepository) processRows(rows *sql.Rows) ([]domain.DiscordW
CreatedAt: createdAt, CreatedAt: createdAt,
UpdatedAt: updatedAt, UpdatedAt: updatedAt,
DeletedAt: deletedAt, DeletedAt: deletedAt,
UserID: userId,
Url: url, Url: url,
Server: server, Server: server,
Channel: channel, Channel: channel,

View File

@ -17,7 +17,7 @@ func TestCreateDiscordWebHookRecord(t *testing.T) {
defer db.Close() defer db.Close()
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
created, err := r.Create(context.Background(), "www.discord.com/bad/webhook", "Unit Testing", "memes", true) created, err := r.Create(context.Background(), 999, "www.discord.com/bad/webhook", "Unit Testing", "memes", true)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()
@ -38,7 +38,7 @@ func TestDiscordWebHookGetById(t *testing.T) {
defer db.Close() defer db.Close()
ctx := context.Background() ctx := context.Background()
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
created, err := r.Create(ctx, "www.discord.com/bad/webhook", "Unit Testing", "memes", true) created, err := r.Create(ctx, 999, "www.discord.com/bad/webhook", "Unit Testing", "memes", true)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()
@ -71,7 +71,7 @@ func TestDiscordWebHookGetByUrl(t *testing.T) {
ctx := context.Background() ctx := context.Background()
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
_, _ = r.Create(ctx, "www.discord.com/bad/webhook", "Unit Testing", "memes", true) _, _ = r.Create(ctx, 999, "www.discord.com/bad/webhook", "Unit Testing", "memes", true)
item, err := r.GetByUrl(ctx, "www.discord.com/bad/webhook") item, err := r.GetByUrl(ctx, "www.discord.com/bad/webhook")
if err != nil { if err != nil {
t.Log(err) t.Log(err)
@ -95,7 +95,7 @@ func TestDiscordWebHookListByServerName(t *testing.T) {
ctx := context.Background() ctx := context.Background()
serverName := "Unit Testing" serverName := "Unit Testing"
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
_, _ = r.Create(ctx, "www.discord.com/bad/webhook", serverName, "memes", true) _, _ = r.Create(ctx, 999, "www.discord.com/bad/webhook", serverName, "memes", true)
item, err := r.ListByServerName(ctx, serverName) item, err := r.ListByServerName(ctx, serverName)
if err != nil { if err != nil {
@ -121,7 +121,7 @@ func TestDiscordWebHookListByServerAndChannel(t *testing.T) {
serverName := "Unit Testing" serverName := "Unit Testing"
channel := "memes" channel := "memes"
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
_, _ = r.Create(ctx, "www.discord.com/bad/webhook", serverName, channel, true) _, _ = r.Create(ctx, 999, "www.discord.com/bad/webhook", serverName, channel, true)
item, err := r.ListByServerAndChannel(ctx, serverName, channel) item, err := r.ListByServerAndChannel(ctx, serverName, channel)
if err != nil { if err != nil {
@ -152,7 +152,7 @@ func TestDiscordWebHookEnableRecord(t *testing.T) {
serverName := "Unit Testing" serverName := "Unit Testing"
channel := "memes" channel := "memes"
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
_, _ = r.Create(ctx, "www.discord.com/bad/webhook", serverName, channel, false) _, _ = r.Create(ctx, 999, "www.discord.com/bad/webhook", serverName, channel, false)
item, err := r.GetById(ctx, 1) item, err := r.GetById(ctx, 1)
if err != nil { if err != nil {
@ -195,7 +195,7 @@ func TestDiscordWebHookDisableRecord(t *testing.T) {
serverName := "Unit Testing" serverName := "Unit Testing"
channel := "memes" channel := "memes"
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
_, _ = r.Create(ctx, "www.discord.com/bad/webhook", serverName, channel, true) _, _ = r.Create(ctx, 999, "www.discord.com/bad/webhook", serverName, channel, true)
item, err := r.GetById(ctx, 1) item, err := r.GetById(ctx, 1)
if err != nil { if err != nil {
@ -238,7 +238,7 @@ func TestDiscordWebHookSoftDelete(t *testing.T) {
serverName := "Unit Testing" serverName := "Unit Testing"
channel := "memes" channel := "memes"
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
_, _ = r.Create(ctx, "www.discord.com/bad/webhook", serverName, channel, true) _, _ = r.Create(ctx, 999, "www.discord.com/bad/webhook", serverName, channel, true)
_, err = r.SoftDelete(ctx, 1) _, err = r.SoftDelete(ctx, 1)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
@ -263,7 +263,7 @@ func TestDiscordWebHookRestore(t *testing.T) {
timeZero := time.Time{} timeZero := time.Time{}
r := repository.NewDiscordWebHookRepository(db) r := repository.NewDiscordWebHookRepository(db)
_, _ = r.Create(ctx, "www.discord.com/bad/webhook", serverName, channel, true) _, _ = r.Create(ctx, 999, "www.discord.com/bad/webhook", serverName, channel, true)
item, _ := r.GetById(ctx, 1) item, _ := r.GetById(ctx, 1)
if item.DeletedAt != timeZero { if item.DeletedAt != timeZero {
t.Log("DeletedAt was not zero") t.Log("DeletedAt was not zero")

View File

@ -0,0 +1,120 @@
package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"git.jamestombleson.com/jtom38/newsbot-api/internal/domain"
"github.com/huandu/go-sqlbuilder"
)
type UserSourceRepo interface {
Create(ctx context.Context, userId, sourceId int64) (int64, error)
SoftDelete(ctx context.Context, id int64) (int64, error)
Restore(ctx context.Context, id int64) (int64, error)
Delete(ctx context.Context, id int64) (int64, error)
ListByUser(ctx context.Context, page, limit int, userId int64) ([]domain.UserSourceSubscriptionEntity, error)
}
type userSourceRepository struct {
conn *sql.DB
defaultLimit int
defaultOffset int
}
func NewUserSourceRepository(conn *sql.DB) userSourceRepository {
return userSourceRepository{
conn: conn,
defaultLimit: 50,
defaultOffset: 50,
}
}
func (r userSourceRepository) Create(ctx context.Context, userId, sourceId int64) (int64, error) {
dt := time.Now()
queryBuilder := sqlbuilder.NewInsertBuilder()
queryBuilder.InsertInto("UserSourceSubscriptions")
queryBuilder.Cols("UpdatedAt", "CreatedAt", "DeletedAt", "UserID", "SourceID")
queryBuilder.Values(dt, dt, timeZero, userId, sourceId)
query, args := queryBuilder.Build()
_, err := r.conn.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return 1, nil
}
func (r userSourceRepository) SoftDelete(ctx context.Context, id int64) (int64, error) {
return softDeleteRow(ctx, r.conn, "UserSourceSubscriptions", id)
}
func (r userSourceRepository) Restore(ctx context.Context, id int64) (int64, error) {
return restoreRow(ctx, r.conn, "UserSourceSubscriptions", id)
}
func (r userSourceRepository) Delete(ctx context.Context, id int64) (int64, error) {
return deleteFromTable(ctx, r.conn, "UserSourceSubscriptions", id)
}
func (r userSourceRepository) ListByUser(ctx context.Context, page, limit int, userId int64) ([]domain.UserSourceSubscriptionEntity, error) {
builder := sqlbuilder.NewSelectBuilder()
builder.Select("*")
builder.From("UserSourceSubscriptions")
builder.Where(
builder.Equal("UserID", userId),
)
builder.Offset(page * limit)
builder.Limit(limit)
query, args := builder.Build()
rows, err := r.conn.QueryContext(ctx, query, args...)
if err != nil {
return []domain.UserSourceSubscriptionEntity{}, err
}
data := r.processRows(rows)
if len(data) == 0 {
return []domain.UserSourceSubscriptionEntity{}, errors.New(ErrUserNotFound)
}
return data, nil
}
func (ur userSourceRepository) processRows(rows *sql.Rows) []domain.UserSourceSubscriptionEntity {
items := []domain.UserSourceSubscriptionEntity{}
for rows.Next() {
var id int64
var createdAt time.Time
var updatedAt time.Time
var deletedAt time.Time
var userId int64
var sourceId int64
err := rows.Scan(
&id, &createdAt, &updatedAt, &deletedAt,
&userId, &sourceId,
)
if err != nil {
fmt.Println(err)
}
item := domain.UserSourceSubscriptionEntity{
ID: id,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
DeletedAt: deletedAt,
UserID: userId,
SourceID: sourceId,
}
items = append(items, item)
}
return items
}

View File

@ -34,6 +34,7 @@ const (
type Configs struct { type Configs struct {
ServerAddress string ServerAddress string
JwtSecret string
RedditEnabled bool RedditEnabled bool
RedditPullTop bool RedditPullTop bool
@ -64,6 +65,7 @@ func NewConfig() ConfigClient {
func GetEnvConfig() Configs { func GetEnvConfig() Configs {
return Configs{ return Configs{
ServerAddress: os.Getenv(ServerAddress), ServerAddress: os.Getenv(ServerAddress),
JwtSecret: os.Getenv("JwtSecret"),
RedditEnabled: processBoolConfig(os.Getenv(FEATURE_ENABLE_REDDIT_BACKEND)), RedditEnabled: processBoolConfig(os.Getenv(FEATURE_ENABLE_REDDIT_BACKEND)),
RedditPullTop: processBoolConfig(os.Getenv(REDDIT_PULL_TOP)), RedditPullTop: processBoolConfig(os.Getenv(REDDIT_PULL_TOP)),

View File

@ -7,19 +7,23 @@ import (
) )
type RepositoryService struct { type RepositoryService struct {
Articles repository.ArticlesRepo AlertDiscord repository.AlertDiscordRepo
DiscordWebHooks repository.DiscordWebHookRepo Articles repository.ArticlesRepo
Sources repository.Sources DiscordWebHooks repository.DiscordWebHookRepo
Users repository.Users RefreshTokens repository.RefreshToken
RefreshTokens repository.RefreshToken Sources repository.Sources
Users repository.Users
UserSourceSubscriptions repository.UserSourceRepo
} }
func NewRepositoryService(conn *sql.DB) RepositoryService { func NewRepositoryService(conn *sql.DB) RepositoryService {
return RepositoryService{ return RepositoryService{
Articles: repository.NewArticleRepository(conn), AlertDiscord: repository.NewAlertDiscordRepository(conn),
DiscordWebHooks: repository.NewDiscordWebHookRepository(conn), Articles: repository.NewArticleRepository(conn),
Sources: repository.NewSourceRepository(conn), DiscordWebHooks: repository.NewDiscordWebHookRepository(conn),
Users: repository.NewUserRepository(conn), RefreshTokens: repository.NewRefreshTokenRepository(conn),
RefreshTokens: repository.NewRefreshTokenRepository(conn), Sources: repository.NewSourceRepository(conn),
Users: repository.NewUserRepository(conn),
UserSourceSubscriptions: repository.NewUserSourceRepository(conn),
} }
} }

View File

@ -6,6 +6,7 @@ build: ## builds the application with the current go runtime
~/go/bin/swag f ~/go/bin/swag f
~/go/bin/swag init -g cmd/server.go ~/go/bin/swag init -g cmd/server.go
go build cmd/server.go go build cmd/server.go
ls -lh server
docker-build: ## Generates the docker image docker-build: ## Generates the docker image
docker build -t "newsbot.collector.api" . docker build -t "newsbot.collector.api" .