diff --git a/internal/domain/entities.go b/internal/domain/entities.go index 70091e7..edd2557 100644 --- a/internal/domain/entities.go +++ b/internal/domain/entities.go @@ -15,7 +15,6 @@ type RefreshTokenEntity struct { Id int64 Username string Token string - ExpiresAt time.Time CreatedAt time.Time LastUpdated time.Time } diff --git a/internal/migrations/20240416180636_refreshtoken.sql b/internal/migrations/20240416180636_refreshtoken.sql index 7457fb4..ce13bff 100644 --- a/internal/migrations/20240416180636_refreshtoken.sql +++ b/internal/migrations/20240416180636_refreshtoken.sql @@ -5,7 +5,6 @@ CREATE TABLE RefreshTokens ( ID INTEGER PRIMARY KEY AUTOINCREMENT, Username TEXT NOT NULL, Token TEXT NOT NULL, - ExpiresAt DATETIME NOT NULL, CreatedAt DATETIME NOT NULL, LastUpdated DATETIME NOT NULL ) diff --git a/internal/repositories/refreshTokens.go b/internal/repositories/refreshTokens.go index 16e3648..9e8889b 100644 --- a/internal/repositories/refreshTokens.go +++ b/internal/repositories/refreshTokens.go @@ -15,7 +15,7 @@ const ( ) type RefreshTokenTable interface { - Create(username string, token string, expiresAt time.Time) (int64, error) + Create(username string, token string) (int64, error) GetByUsername(name string) (domain.RefreshTokenEntity, error) DeleteById(id int64) (int64, error) } @@ -30,12 +30,12 @@ func NewRefreshTokenRepository(conn *sql.DB) RefreshTokenRepository { } } -func (rt RefreshTokenRepository) Create(username string, token string, expiresAt time.Time) (int64, error) { +func (rt RefreshTokenRepository) Create(username string, token string) (int64, error) { dt := time.Now() builder := sqlbuilder.NewInsertBuilder() builder.InsertInto(refreshTokenTableName) - builder.Cols("Username", "Token", "ExpiresAt", "CreatedAt", "LastUpdated") - builder.Values(username, token, expiresAt, dt, dt) + builder.Cols("Username", "Token", "CreatedAt", "LastUpdated") + builder.Values(username, token, dt, dt) query, args := builder.Build() _, err := rt.connection.Exec(query, args...) @@ -89,11 +89,10 @@ func (rd RefreshTokenRepository) processRows(rows *sql.Rows) []domain.RefreshTok var id int64 var username string var token string - var expiresAt time.Time var createdAt time.Time var lastUpdated time.Time - err := rows.Scan(&id, &username, &token, &expiresAt, &createdAt, &lastUpdated) + err := rows.Scan(&id, &username, &token, &createdAt, &lastUpdated) if err != nil { fmt.Println(err) } @@ -102,7 +101,6 @@ func (rd RefreshTokenRepository) processRows(rows *sql.Rows) []domain.RefreshTok Id: id, Username: username, Token: token, - ExpiresAt: expiresAt, CreatedAt: createdAt, LastUpdated: lastUpdated, }) diff --git a/internal/repositories/refreshTokens_test.go b/internal/repositories/refreshTokens_test.go index 1cd14f4..bd3a98a 100644 --- a/internal/repositories/refreshTokens_test.go +++ b/internal/repositories/refreshTokens_test.go @@ -3,7 +3,6 @@ package repositories_test import ( "database/sql" "testing" - "time" "git.jamestombleson.com/jtom38/go-cook/internal/repositories" _ "github.com/glebarez/go-sqlite" @@ -18,7 +17,7 @@ func TestRefreshTokenCreate(t *testing.T) { } client := repositories.NewRefreshTokenRepository(conn) - rows, err := client.Create("tester", "BadTokenDontUse", time.Now().Add(time.Hour+1)) + rows, err := client.Create("tester", "BadTokenDontUse") if err != nil { t.Log(err) t.FailNow() @@ -37,7 +36,7 @@ func TestRefreshTokenGetByUsername(t *testing.T) { } client := repositories.NewRefreshTokenRepository(conn) - rows, err := client.Create("tester", "BadTokenDoNotUse", time.Now().Add(time.Hour+1)) + rows, err := client.Create("tester", "BadTokenDoNotUse") if err != nil { t.Log(err) t.FailNow() @@ -68,7 +67,7 @@ func TestRefreshTokenDeleteById(t *testing.T) { } client := repositories.NewRefreshTokenRepository(conn) - _, err = client.Create("tester", "BadTokenDoNotUse", time.Now().Add(time.Hour+1)) + _, err = client.Create("tester", "BadTokenDoNotUse") if err != nil { t.Log(err) t.FailNow() diff --git a/internal/services/refreshTokenService.go b/internal/services/refreshTokenService.go index a23c06d..aa1f457 100644 --- a/internal/services/refreshTokenService.go +++ b/internal/services/refreshTokenService.go @@ -3,20 +3,25 @@ package services import ( "database/sql" "errors" - "time" "git.jamestombleson.com/jtom38/go-cook/internal/domain" "git.jamestombleson.com/jtom38/go-cook/internal/repositories" "github.com/google/uuid" ) +const ( + ErrUnexpectedAmountOfRowsUpdated = "got a unexpected of rows updated" +) + type RefreshToken interface { - Create(username string, expiresAt time.Time) (string, error) + Create(username string) (string, error) GetByName(name string) (domain.RefreshTokenEntity, error) Delete(id int64) (int64, error) - IsRequestValid(username, refreshToken string, jwtExpiresAt time.Time) error + IsRequestValid(username, refreshToken string) error } +// A new jwt token can be made if the user has the correct refresh token for the user. +// It will also require the old JWT token so the expire time is pulled and part of the validation type RefreshTokenService struct { table repositories.RefreshTokenTable } @@ -27,13 +32,26 @@ func NewRefreshTokenService(conn *sql.DB) RefreshTokenService { } } -func (rt RefreshTokenService) Create(username string, expiresAt time.Time) (string, error) { +func (rt RefreshTokenService) Create(username string) (string, error) { + //if a refresh token already exists for a user, reuse + existingToken, err := rt.GetByName(username) + if err == nil { + rowsRemoved, err := rt.Delete(existingToken.Id) + if err != nil { + return "", err + } + + if rowsRemoved != 1 { + return "", errors.New(ErrUnexpectedAmountOfRowsUpdated) + } + } + token, err := uuid.NewV7() if err != nil { return "", err } - rows, err := rt.table.Create(username, token.String(), expiresAt) + rows, err := rt.table.Create(username, token.String()) if err != nil { return "", err } @@ -54,19 +72,15 @@ func (rt RefreshTokenService) Delete(id int64) (int64, error) { return rt.table.DeleteById(id) } -func (rt RefreshTokenService) IsRequestValid(username, refreshToken string, jwtExpiresAt time.Time) error { +func (rt RefreshTokenService) IsRequestValid(username, refreshToken string) error { token, err := rt.GetByName(username) if err != nil { return err } - - if (token.Token != refreshToken) { + + if token.Token != refreshToken { return errors.New("the refresh token given does not match") } - if (token.ExpiresAt != jwtExpiresAt) { - return errors.New("the time when the jwt token expires does not match what was given") - } - return nil -} \ No newline at end of file +}