mirror of
https://github.com/muety/wakapi.git
synced 2025-12-05 22:20:24 -08:00
Compare commits
4 Commits
faa6312cd8
...
c320afaf3b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c320afaf3b | ||
|
|
140f8b2eac | ||
|
|
6880d4d524 | ||
|
|
a07a4f48b4 |
10
README.md
10
README.md
@@ -200,13 +200,18 @@ You can specify configuration options either via a config file (default: `config
|
||||
| `security.signup_max_rate` /<br> `WAKAPI_SIGNUP_MAX_RATE` | `5/1h` | Rate limiting config for signup endpoint in format `<max_req>/<multiplier><unit>`, where `unit` is one of `s`, `m` or `h`. |
|
||||
| `security.login_max_rate` /<br> `WAKAPI_LOGIN_MAX_RATE` | `10/1m` | Rate limiting config for login endpoint in format `<max_req>/<multiplier><unit>`, where `unit` is one of `s`, `m` or `h`. |
|
||||
| `security.password_reset_max_rate` /<br> `WAKAPI_PASSWORD_RESET_MAX_RATE` | `5/1h` | Rate limiting config for password reset endpoint in format `<max_req>/<multiplier><unit>`, where `unit` is one of `s`, `m` or `h`. |
|
||||
| `security.oidc` | `[]` | List of OpenID Connect provider configurations (for details, see [wiki](https://github.com/muety/wakapi/wiki/OpenID-Connect-login-(SSO))) |
|
||||
| `security.oidc[0].name` /<br> `WAKAPI_OIDC_PROVIDER_NAME` | - | Name / identifier for the OpenID Connect provider (e.g. `gitlab`) |
|
||||
| `security.oidc[0].client_id` /<br> `WAKAPI_OIDC_PROVIDER_CLIENT_ID` | - | OAuth client name with this provider |
|
||||
| `security.oidc[0].client_secret` /<br> `WAKAPI_OIDC_PROVIDER_CLIENT_SECRET` | - | OAuth client secret with this provider |
|
||||
| `security.oidc[0].endpoint` /<br> `WAKAPI_OIDC_PROVIDER_ENDPOINT` | - | OpenID Connect provider API entrypoint (for [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html)) |
|
||||
| `db.host` /<br> `WAKAPI_DB_HOST` | - | Database host |
|
||||
| `db.port` /<br> `WAKAPI_DB_PORT` | - | Database port |
|
||||
| `db.socket` /<br> `WAKAPI_DB_SOCKET` | - | Database UNIX socket (alternative to `host`) (for MySQL only) |
|
||||
| `db.user` /<br> `WAKAPI_DB_USER` | - | Database user |
|
||||
| `db.password` /<br> `WAKAPI_DB_PASSWORD` | - | Database password |
|
||||
| `db.name` /<br> `WAKAPI_DB_NAME` | `wakapi_db.db` | Database name |
|
||||
| `db.dialect` /<br> `WAKAPI_DB_TYPE` | `sqlite3` | Database type (one of `sqlite3`, `mysql`, `postgres`) |
|
||||
| `db.dialect` /<br> `WAKAPI_DB_TYPE` | `sqlite3` | Database type (one of `sqlite3`, `mysql`, `postgres`) |
|
||||
| `db.charset` /<br> `WAKAPI_DB_CHARSET` | `utf8mb4` | Database connection charset (for MySQL only) |
|
||||
| `db.max_conn` /<br> `WAKAPI_DB_MAX_CONNECTIONS` | `2` | Maximum number of database connections |
|
||||
| `db.ssl` /<br> `WAKAPI_DB_SSL` | `false` | Whether to use TLS encryption for database connection (Postgres only) |
|
||||
@@ -252,7 +257,8 @@ Wakapi supports different types of user authentication.
|
||||
* Warning: This type of authentication is quite prone to misconfiguration. Make sure that your reverse proxy properly strips relevant headers from client requests.
|
||||
|
||||
### Single Sign-On / OpenID Connect
|
||||
Wakapi supports login via external identity providers via OpenID Connect. See [our wiki](https://github.com/muety/wakapi/wiki/OpenID-Connect-login-(SSO)) for details.
|
||||
|
||||
Wakapi supports login via external identity providers via OpenID Connect. See [our wiki](https://github.com/muety/wakapi/wiki/OpenID-Connect-login-(SSO)) for details.
|
||||
|
||||
## 🔧 API endpoints
|
||||
|
||||
|
||||
@@ -205,10 +205,10 @@ type SMTPMailConfig struct {
|
||||
}
|
||||
|
||||
type oidcProviderConfig struct {
|
||||
Name string `yaml:"name" env:"WAKAPI_OIDC_PROVIDERS_0_NAME"`
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
Endpoint string `yaml:"endpoint"` // base url from which auto-discovery (.well-known/openid-configuration) can be found
|
||||
Name string `yaml:"name" env:"WAKAPI_OIDC_PROVIDER_NAME"`
|
||||
ClientID string `yaml:"client_id" env:"WAKAPI_OIDC_PROVIDER_CLIENT_ID"`
|
||||
ClientSecret string `yaml:"client_secret" env:"WAKAPI_OIDC_PROVIDER_CLIENT_SECRET"`
|
||||
Endpoint string `yaml:"endpoint" env:"WAKAPI_OIDC_PROVIDER_ENDPOINT"` // base url from which auto-discovery (.well-known/openid-configuration) can be found
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
|
||||
@@ -25,7 +25,7 @@ type IdTokenPayload struct {
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
ProviderName string `json:"-"`
|
||||
ProviderName string `json:"provider_name"` // custom field, not part of actual id token response
|
||||
}
|
||||
|
||||
func (token *IdTokenPayload) Exp() time.Time {
|
||||
|
||||
13
config/testutils.go
Normal file
13
config/testutils.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package config
|
||||
|
||||
func WithOidcProvider(c *Config, name, clientId, clientSecret, Endpoint string) *Config {
|
||||
providerConf := oidcProviderConfig{
|
||||
Name: name,
|
||||
ClientID: clientId,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: Endpoint,
|
||||
}
|
||||
c.Security.OidcProviders = append(c.Security.OidcProviders, providerConf)
|
||||
RegisterOidcProvider(&providerConf) // config must be Set() for this to work
|
||||
return c
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
3
go.mod
3
go.mod
@@ -29,6 +29,7 @@ require (
|
||||
github.com/mileusna/useragent v1.3.5
|
||||
github.com/muety/artifex/v2 v2.0.1-0.20221201142708-74e7d3f6feaf
|
||||
github.com/narqo/go-badge v0.0.0-20230821190521-c9a75c019a59
|
||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
@@ -45,6 +46,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-jose/go-jose/v3 v3.0.1 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/go-openapi/swag/conv v0.25.1 // indirect
|
||||
github.com/go-openapi/swag/jsonname v0.25.1 // indirect
|
||||
@@ -53,6 +55,7 @@ require (
|
||||
github.com/go-openapi/swag/stringutils v0.25.1 // indirect
|
||||
github.com/go-openapi/swag/typeutils v0.25.1 // indirect
|
||||
github.com/go-openapi/swag/yamlutils v0.25.1 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -74,6 +74,8 @@ github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5
|
||||
github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||
github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA=
|
||||
github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-openapi/jsonpointer v0.22.1 h1:sHYI1He3b9NqJ4wXLoJDKmUmHkWy/L7rtEo92JUxBNk=
|
||||
@@ -115,6 +117,7 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
@@ -189,6 +192,8 @@ github.com/narqo/go-badge v0.0.0-20230821190521-c9a75c019a59 h1:kbREB9muGo4sHLoZ
|
||||
github.com/narqo/go-badge v0.0.0-20230821190521-c9a75c019a59/go.mod h1:m9BzkaxwU4IfPQi9ko23cmuFltayFe8iS0dlRlnEWiM=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 h1:9bCMuD3TcnjeqjPT2gSlha4asp8NvgcFRYExCaikCxk=
|
||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25/go.mod h1:eDjgYHYDJbPLBLsyZ6qRaugP0mX8vePOhZ5id1fdzJw=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
@@ -225,6 +230,7 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
@@ -254,6 +260,7 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
@@ -386,6 +393,7 @@ golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE=
|
||||
golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
|
||||
@@ -75,6 +75,8 @@ func (m *AuthenticateMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques
|
||||
var user *models.User
|
||||
|
||||
if m.tryHandleOidc(w, r) {
|
||||
// user has expired oidc token, thus is redirected to provider and will come back to callback endpoint
|
||||
// notably, if user does have a valid, non-expired id token, they will also have a valid auth cookie, so proceed as usual
|
||||
return
|
||||
}
|
||||
|
||||
@@ -213,13 +215,14 @@ func (m *AuthenticateMiddleware) tryGetUserByCookie(r *http.Request) (*models.Us
|
||||
}
|
||||
|
||||
// redirect if oidc id token was found, but expired
|
||||
// returns true if further authentication can be skipped
|
||||
func (m *AuthenticateMiddleware) tryHandleOidc(w http.ResponseWriter, r *http.Request) bool {
|
||||
idToken := routeutils.GetOidcIdTokenPayload(r)
|
||||
if idToken == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !idToken.IsValid() {
|
||||
if !idToken.IsValid() { // expired
|
||||
provider, err := m.config.Security.GetOidcProvider(idToken.ProviderName)
|
||||
if err != nil {
|
||||
conf.Log().Request(r).Error("failed to get provider from id token", "provider", idToken.ProviderName, "sub", idToken.Subject)
|
||||
|
||||
@@ -5,10 +5,15 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muety/wakapi/config"
|
||||
routeutils "github.com/muety/wakapi/routes/utils"
|
||||
"github.com/oauth2-proxy/mockoidc"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/muety/wakapi/mocks"
|
||||
@@ -250,4 +255,107 @@ func TestAuthenticateMiddleware_tryGetUserByTrustedHeader_Signup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticateMiddleware_tryHandleOidc_NoToken(t *testing.T) {
|
||||
config.Set(config.Empty())
|
||||
|
||||
userServiceMock := new(mocks.UserServiceMock)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/summary", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
sut := NewAuthenticateMiddleware(userServiceMock)
|
||||
|
||||
assert.False(t, sut.tryHandleOidc(w, r))
|
||||
assert.NotEqual(t, w.Code, http.StatusTemporaryRedirect)
|
||||
assert.NotEqual(t, w.Code, http.StatusFound)
|
||||
}
|
||||
|
||||
func TestAuthenticateMiddleware_tryHandleOidc_InvalidToken_ExistingUser(t *testing.T) {
|
||||
const (
|
||||
testProvider = "mock"
|
||||
testSub = "testsub"
|
||||
)
|
||||
var testUser = &models.User{ID: "testuser"}
|
||||
|
||||
oidcMock, _ := mockoidc.Run()
|
||||
defer oidcMock.Shutdown()
|
||||
|
||||
cfg := config.Empty()
|
||||
config.Set(cfg)
|
||||
config.WithOidcProvider(cfg, testProvider, oidcMock.ClientID, oidcMock.ClientSecret, oidcMock.Addr()+"/oidc")
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/summary", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
testIdToken := &config.IdTokenPayload{
|
||||
Subject: testSub,
|
||||
Expiry: time.Now().Add(-time.Minute).Unix(),
|
||||
ProviderName: testProvider,
|
||||
}
|
||||
routeutils.SetOidcIdTokenPayload(testIdToken, r, w)
|
||||
|
||||
userServiceMock := new(mocks.UserServiceMock)
|
||||
userServiceMock.On("GetUserByOidc", testProvider, testSub).Return(testUser, nil)
|
||||
|
||||
sut := NewAuthenticateMiddleware(userServiceMock)
|
||||
|
||||
assert.True(t, sut.tryHandleOidc(w, r))
|
||||
assert.Equal(t, w.Code, http.StatusFound)
|
||||
assert.True(t, strings.HasPrefix(w.Header().Get("Location"), oidcMock.AuthorizationEndpoint()))
|
||||
assert.NotEmpty(t, routeutils.GetOidcState(r))
|
||||
assert.Contains(t, w.Header().Get("Location"), fmt.Sprintf("state=%s", routeutils.GetOidcState(r)))
|
||||
}
|
||||
|
||||
func TestAuthenticateMiddleware_tryHandleOidc_InvalidToken_NonExistingUser(t *testing.T) {
|
||||
const (
|
||||
testProvider = "mock"
|
||||
testSub = "testsub"
|
||||
)
|
||||
|
||||
oidcMock, _ := mockoidc.Run()
|
||||
defer oidcMock.Shutdown()
|
||||
|
||||
cfg := config.Empty()
|
||||
config.Set(cfg)
|
||||
config.WithOidcProvider(cfg, testProvider, oidcMock.ClientID, oidcMock.ClientSecret, oidcMock.Addr()+"/oidc")
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/summary", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
testIdToken := &config.IdTokenPayload{
|
||||
Subject: testSub,
|
||||
Expiry: time.Now().Add(-time.Minute).Unix(),
|
||||
ProviderName: testProvider,
|
||||
}
|
||||
routeutils.SetOidcIdTokenPayload(testIdToken, r, w)
|
||||
|
||||
userServiceMock := new(mocks.UserServiceMock)
|
||||
userServiceMock.On("GetUserByOidc", testProvider, testSub).Return(nil, errors.New(""))
|
||||
|
||||
sut := NewAuthenticateMiddleware(userServiceMock)
|
||||
|
||||
assert.False(t, sut.tryHandleOidc(w, r))
|
||||
assert.NotEqual(t, w.Code, http.StatusTemporaryRedirect)
|
||||
assert.NotEqual(t, w.Code, http.StatusFound)
|
||||
}
|
||||
|
||||
func TestAuthenticateMiddleware_tryHandleOidc_ValidToken(t *testing.T) {
|
||||
config.Set(config.Empty())
|
||||
|
||||
userServiceMock := new(mocks.UserServiceMock)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/summary", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routeutils.SetOidcIdTokenPayload(&config.IdTokenPayload{
|
||||
Expiry: time.Now().Add(1 * time.Minute).Unix(),
|
||||
}, r, w)
|
||||
|
||||
sut := NewAuthenticateMiddleware(userServiceMock)
|
||||
|
||||
assert.False(t, sut.tryHandleOidc(w, r))
|
||||
assert.NotEqual(t, w.Code, http.StatusTemporaryRedirect)
|
||||
assert.NotEqual(t, w.Code, http.StatusFound)
|
||||
}
|
||||
|
||||
// TODO: somehow test cookie auth function
|
||||
|
||||
@@ -34,6 +34,9 @@ func (m *UserServiceMock) GetUserByResetToken(s string) (*models.User, error) {
|
||||
|
||||
func (m *UserServiceMock) GetUserByOidc(s1, s2 string) (*models.User, error) {
|
||||
args := m.Called(s1, s2)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*models.User), args.Error(1)
|
||||
}
|
||||
|
||||
|
||||
@@ -242,7 +242,7 @@ func (c *SetPasswordRequest) IsValid() bool {
|
||||
func (s *Signup) IsValid() bool {
|
||||
config := conf.Get()
|
||||
|
||||
captchaValid := s.OidcProvider != ""
|
||||
captchaValid := s.OidcProvider != "" || !config.Security.SignupCaptcha
|
||||
if !captchaValid && config.Security.SignupCaptcha {
|
||||
captchaValid = ValidateCaptcha(s.CaptchaId, s.Captcha)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/muety/wakapi/config"
|
||||
"github.com/muety/wakapi/middlewares"
|
||||
"github.com/muety/wakapi/mocks"
|
||||
"github.com/muety/wakapi/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -15,6 +8,15 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/muety/wakapi/config"
|
||||
"github.com/muety/wakapi/middlewares"
|
||||
"github.com/muety/wakapi/mocks"
|
||||
"github.com/muety/wakapi/models"
|
||||
"github.com/muety/wakapi/routes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -61,7 +63,7 @@ func TestBadgeHandler_Get(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/badge/{user}/interval:week/language:go", nil)
|
||||
req = withUrlParam(req, "user", "user1")
|
||||
req = routes.WithUrlParam(req, "user", "user1")
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
@@ -82,7 +84,7 @@ func TestBadgeHandler_Get(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/badge/{user}/interval:last_year/language:go", nil)
|
||||
req = withUrlParam(req, "user", "user1")
|
||||
req = routes.WithUrlParam(req, "user", "user1")
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
@@ -102,7 +104,7 @@ func TestBadgeHandler_Get(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/badge/{user}/interval:year/project:foo", nil)
|
||||
req = withUrlParam(req, "user", "user1")
|
||||
req = routes.WithUrlParam(req, "user", "user1")
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
|
||||
@@ -97,7 +97,7 @@ func (h *LoginHandler) PostLogin(w http.ResponseWriter, r *http.Request) {
|
||||
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w, false).WithError("missing parameters"))
|
||||
return
|
||||
}
|
||||
if err := loginDecoder.Decode(&login, r.PostForm); err != nil {
|
||||
if err := loginDecoder.Decode(&login, r.PostForm); err != nil || login.Username == "" || login.Password == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
templates[conf.LoginTemplate].Execute(w, h.buildViewModel(r, w, false).WithError("missing parameters"))
|
||||
return
|
||||
@@ -128,8 +128,8 @@ func (h *LoginHandler) PostLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if user := middlewares.GetPrincipal(r); user != nil {
|
||||
h.userSrvc.FlushUserCache(user.ID)
|
||||
}
|
||||
routeutils.ClearSession(r, w)
|
||||
http.SetCookie(w, h.config.GetClearCookie(models.AuthCookieKey))
|
||||
routeutils.ClearSession(r, w) // clear all session data
|
||||
http.SetCookie(w, h.config.GetClearCookie(models.AuthCookieKey)) // clear auth token
|
||||
http.Redirect(w, r, fmt.Sprintf("%s/", h.config.Server.BasePath), http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -178,14 +178,16 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) {
|
||||
var invitedDate time.Time
|
||||
var inviteCodeKey = fmt.Sprintf("%s_%s", conf.KeyInviteCode, signup.InviteCode)
|
||||
|
||||
if kv, _ := h.keyValueSrvc.GetString(inviteCodeKey); kv != nil && kv.Value != "" {
|
||||
if parts := strings.Split(kv.Value, ","); len(parts) == 2 {
|
||||
invitedBy = parts[0]
|
||||
invitedDate, _ = time.Parse(time.RFC3339, parts[1])
|
||||
}
|
||||
if signup.InviteCode != "" {
|
||||
if kv, _ := h.keyValueSrvc.GetString(inviteCodeKey); kv != nil && kv.Value != "" {
|
||||
if parts := strings.Split(kv.Value, ","); len(parts) == 2 {
|
||||
invitedBy = parts[0]
|
||||
invitedDate, _ = time.Parse(time.RFC3339, parts[1])
|
||||
}
|
||||
|
||||
if err := h.keyValueSrvc.DeleteString(inviteCodeKey); err != nil {
|
||||
conf.Log().Error("failed to revoke invite code", "inviteCodeKey", inviteCodeKey, "error", err)
|
||||
if err := h.keyValueSrvc.DeleteString(inviteCodeKey); err != nil {
|
||||
conf.Log().Error("failed to revoke invite code", "inviteCodeKey", inviteCodeKey, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,7 +215,7 @@ func (h *LoginHandler) PostSignup(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
conf.Log().Request(r).Error("failed to create new user", "error", err)
|
||||
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w, h.config.Security.SignupCaptcha).WithError("failed to create new user"))
|
||||
templates[conf.SignupTemplate].Execute(w, h.buildViewModel(r, w, h.config.Security.SignupCaptcha).WithError("failed to create new user (username or e-mail already existing?)"))
|
||||
return
|
||||
}
|
||||
if !created {
|
||||
@@ -361,7 +363,7 @@ func (h *LoginHandler) PostResetPassword(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (h *LoginHandler) GetOidcLogin(w http.ResponseWriter, r *http.Request) {
|
||||
provider := h.getOpenIdConnect(w, r)
|
||||
provider := h.getOidcProvider(w, r)
|
||||
if provider == nil {
|
||||
return // redirect done in previous method
|
||||
}
|
||||
@@ -370,7 +372,7 @@ func (h *LoginHandler) GetOidcLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *LoginHandler) GetOidcCallback(w http.ResponseWriter, r *http.Request) {
|
||||
provider := h.getOpenIdConnect(w, r)
|
||||
provider := h.getOidcProvider(w, r)
|
||||
if provider == nil {
|
||||
return // redirect done in previous method
|
||||
}
|
||||
@@ -378,9 +380,12 @@ func (h *LoginHandler) GetOidcCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
// clear any existing id token on the session, just because
|
||||
routeutils.ClearOidcIdTokenPayload(r, w)
|
||||
|
||||
// validate oauth state param
|
||||
savedState := routeutils.GetOidcState(r)
|
||||
if savedState != state {
|
||||
if state == "" || savedState != state {
|
||||
errMsg := "suspicious operation, got invalid state in oidc callback"
|
||||
conf.Log().Request(r).Error(errMsg, "saved_state", savedState, "state", state, "provider", provider.Name)
|
||||
routeutils.SetError(r, w, errMsg)
|
||||
@@ -392,7 +397,9 @@ func (h *LoginHandler) GetOidcCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// exchange auth code for access token and id token
|
||||
authToken, err := provider.OAuth2.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
routeutils.SetError(r, w, "failed to exchange authorization code for access token")
|
||||
errMsg := "failed to exchange authorization code for access token"
|
||||
conf.Log().Request(r).Error(errMsg, "provider", provider.Name)
|
||||
routeutils.SetError(r, w, errMsg)
|
||||
http.Redirect(w, r, fmt.Sprintf("%s/login", h.config.Server.BasePath), http.StatusFound)
|
||||
return
|
||||
}
|
||||
@@ -400,7 +407,9 @@ func (h *LoginHandler) GetOidcCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// extract id token
|
||||
rawIdToken, ok := authToken.Extra("id_token").(string)
|
||||
if !ok {
|
||||
routeutils.SetError(r, w, "failed to extract id_token")
|
||||
errMsg := "failed to extract id_token"
|
||||
conf.Log().Request(r).Error(errMsg, "provider", provider.Name)
|
||||
routeutils.SetError(r, w, errMsg)
|
||||
http.Redirect(w, r, fmt.Sprintf("%s/login", h.config.Server.BasePath), http.StatusFound)
|
||||
return
|
||||
}
|
||||
@@ -408,11 +417,12 @@ func (h *LoginHandler) GetOidcCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// verify id token
|
||||
idTokenPayload, err := routeutils.DecodeOidcIdToken(rawIdToken, provider, r.Context())
|
||||
if err != nil || idTokenPayload == nil {
|
||||
routeutils.SetError(r, w, "failed to verify and decode id_token")
|
||||
errMsg := "failed to verify and decode id_token"
|
||||
conf.Log().Request(r).Error(errMsg, "provider", provider.Name, "id_token", rawIdToken) // save to log, because does not grant any access
|
||||
routeutils.SetError(r, w, errMsg)
|
||||
http.Redirect(w, r, fmt.Sprintf("%s/login", h.config.Server.BasePath), http.StatusFound)
|
||||
return
|
||||
}
|
||||
routeutils.SetOidcIdTokenPayload(idTokenPayload, r, w) // save to session, only used by middleware for automatic redirection upon expiry
|
||||
|
||||
user, err := h.userSrvc.GetUserByOidc(provider.Name, idTokenPayload.Subject)
|
||||
if err != nil {
|
||||
@@ -431,11 +441,11 @@ func (h *LoginHandler) GetOidcCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if newUsername := h.coalesceExistingUser(signup.Username); newUsername != signup.Username {
|
||||
conf.Log().Request(r).Warn("username from id token already exist, using suffixed one instead", "username", newUsername)
|
||||
slog.Warn("username from id token already exist, using suffixed one instead", "username", newUsername)
|
||||
signup.Username = newUsername
|
||||
}
|
||||
|
||||
conf.Log().Request(r).Info("creating new user from successful oidc authentication",
|
||||
slog.Info("creating new user from successful oidc authentication",
|
||||
"provider", signup.OidcProvider,
|
||||
"username", signup.Username,
|
||||
"email", signup.Email,
|
||||
@@ -445,13 +455,14 @@ func (h *LoginHandler) GetOidcCallback(w http.ResponseWriter, r *http.Request) {
|
||||
newUser, created, err := h.userSrvc.CreateOrGet(signup, false)
|
||||
if err != nil || !created {
|
||||
conf.Log().Request(r).Error("failed to create new user", "error", err)
|
||||
routeutils.SetError(r, w, "failed to create new user")
|
||||
routeutils.SetError(r, w, "failed to create new user (username or e-mail already existing?)")
|
||||
http.Redirect(w, r, fmt.Sprintf("%s/login", h.config.Server.BasePath), http.StatusFound)
|
||||
return
|
||||
}
|
||||
user = newUser
|
||||
}
|
||||
|
||||
routeutils.SetOidcIdTokenPayload(idTokenPayload, r, w) // save to session, only used by middleware for automatic redirection upon expiry
|
||||
h.finishUserLogin(user, r, w)
|
||||
http.Redirect(w, r, fmt.Sprintf("%s/summary", h.config.Server.BasePath), http.StatusFound)
|
||||
}
|
||||
@@ -474,7 +485,7 @@ func (h *LoginHandler) buildViewModel(r *http.Request, w http.ResponseWriter, wi
|
||||
return routeutils.WithSessionMessages(vm, r, w)
|
||||
}
|
||||
|
||||
func (h *LoginHandler) getOpenIdConnect(w http.ResponseWriter, r *http.Request) *conf.OidcProvider {
|
||||
func (h *LoginHandler) getOidcProvider(w http.ResponseWriter, r *http.Request) *conf.OidcProvider {
|
||||
providerName := chi.URLParam(r, "provider")
|
||||
provider, err := conf.GetOidcProvider(providerName)
|
||||
if err != nil {
|
||||
|
||||
485
routes/login_test.go
Normal file
485
routes/login_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/muety/wakapi/config"
|
||||
"github.com/muety/wakapi/mocks"
|
||||
"github.com/muety/wakapi/models"
|
||||
routeutils "github.com/muety/wakapi/routes/utils"
|
||||
"github.com/muety/wakapi/utils"
|
||||
"github.com/oauth2-proxy/mockoidc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type LoginHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
TestUser *models.User
|
||||
OidcMock *mockoidc.MockOIDC
|
||||
UserService *mocks.UserServiceMock
|
||||
KeyValueService *mocks.KeyValueServiceMock
|
||||
Cfg *config.Config
|
||||
Sut *LoginHandler
|
||||
OidcUserNew *mockoidc.MockUser
|
||||
OidcUserExisting *mockoidc.MockUser
|
||||
oidcMockDefaultConfig mockoidc.Config
|
||||
}
|
||||
|
||||
const (
|
||||
testProvider = "mock"
|
||||
testOauthCode = "some-code"
|
||||
testOauthState = "some-state"
|
||||
testUserExistingId = "user1"
|
||||
testUserExistingEmail = "foo@example.org"
|
||||
testUserExistingSub = "111"
|
||||
testUserExistingPassword = "supersecret"
|
||||
testUserNewId = "user2"
|
||||
testUserNewEmail = "bar@example.org"
|
||||
testUserNewSub = "222"
|
||||
testUserNewPassword = "ssssshhhhhh"
|
||||
testPasswordSalt = "salty"
|
||||
)
|
||||
|
||||
func (suite *LoginHandlerTestSuite) SetupSuite() {
|
||||
if m, err := mockoidc.Run(); err == nil {
|
||||
suite.OidcMock = m
|
||||
suite.oidcMockDefaultConfig = *suite.OidcMock.Config()
|
||||
}
|
||||
|
||||
testUserPassword, _ := utils.HashPassword(testUserExistingPassword, testPasswordSalt)
|
||||
|
||||
suite.OidcUserNew = &mockoidc.MockUser{
|
||||
Subject: testUserNewSub,
|
||||
Email: testUserNewEmail,
|
||||
PreferredUsername: testUserNewId,
|
||||
}
|
||||
|
||||
suite.OidcUserExisting = &mockoidc.MockUser{
|
||||
Subject: testUserExistingSub,
|
||||
Email: testUserExistingEmail,
|
||||
PreferredUsername: testUserExistingId,
|
||||
}
|
||||
|
||||
suite.TestUser = &models.User{
|
||||
ID: testUserExistingId,
|
||||
Email: testUserExistingEmail,
|
||||
AuthType: testProvider,
|
||||
Sub: testUserExistingSub,
|
||||
Password: testUserPassword,
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TearDownSuite() {
|
||||
suite.OidcMock.Shutdown()
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) BeforeTest(suiteName, testName string) {
|
||||
suite.UserService = new(mocks.UserServiceMock)
|
||||
suite.KeyValueService = new(mocks.KeyValueServiceMock)
|
||||
|
||||
cfg := config.Empty()
|
||||
cfg.Security.SecureCookie = securecookie.New(
|
||||
securecookie.GenerateRandomKey(64),
|
||||
securecookie.GenerateRandomKey(32),
|
||||
)
|
||||
cfg.Security.PasswordSalt = testPasswordSalt
|
||||
config.Set(cfg)
|
||||
suite.Cfg = cfg
|
||||
|
||||
suite.resetOidcMockTtl()
|
||||
suite.setupOidcProvider(testProvider)
|
||||
|
||||
suite.Sut = NewLoginHandler(suite.UserService, nil, suite.KeyValueService)
|
||||
Init() // load templates
|
||||
}
|
||||
|
||||
func TestLoginHandlerTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(LoginHandlerTestSuite))
|
||||
}
|
||||
|
||||
// Test cases
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostLogin_Success() {
|
||||
form := url.Values{}
|
||||
form.Add("username", testUserExistingId)
|
||||
form.Add("password", testUserExistingPassword)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("GetUserById", testUserExistingId).Return(suite.TestUser, nil)
|
||||
suite.UserService.On("Update", mock.Anything).Return(suite.TestUser, nil)
|
||||
|
||||
suite.Sut.PostLogin(w, r)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), "/summary", w.Header().Get("Location"))
|
||||
assert.Contains(suite.T(), w.Header().Get("Set-Cookie"), "wakapi_auth=")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostLogin_ValidAuthCookie() {
|
||||
// TODO: implement this
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostLogin_EmptyLoginForm() {
|
||||
form := url.Values{}
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("Count").Return(1, nil)
|
||||
|
||||
suite.Sut.PostLogin(w, r)
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusBadRequest, w.Code)
|
||||
assert.Contains(suite.T(), string(body), "Missing parameters")
|
||||
assert.Empty(suite.T(), w.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostLogin_NonExistingUser() {
|
||||
form := url.Values{}
|
||||
form.Add("username", "nonexisting")
|
||||
form.Add("password", testUserExistingPassword)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("GetUserById", "nonexisting").Return(nil, errors.New(""))
|
||||
suite.UserService.On("Count").Return(1, nil)
|
||||
|
||||
suite.Sut.PostLogin(w, r)
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusNotFound, w.Code)
|
||||
assert.Contains(suite.T(), string(body), "Resource not found")
|
||||
assert.Empty(suite.T(), w.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostLogin_WrongPassword() {
|
||||
form := url.Values{}
|
||||
form.Add("username", testUserExistingId)
|
||||
form.Add("password", "wrongpassword")
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("GetUserById", testUserExistingId).Return(suite.TestUser, nil)
|
||||
suite.UserService.On("Count").Return(1, nil)
|
||||
|
||||
suite.Sut.PostLogin(w, r)
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(suite.T(), string(body), "Invalid credentials")
|
||||
assert.Empty(suite.T(), w.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostSignup_Success() {
|
||||
form := url.Values{}
|
||||
form.Add("username", testUserNewId)
|
||||
form.Add("email", testUserNewEmail)
|
||||
form.Add("password", testUserNewPassword)
|
||||
form.Add("password_repeat", testUserNewPassword)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("Count", mock.Anything).Return(1, nil)
|
||||
suite.UserService.On("CreateOrGet", mock.Anything, mock.Anything).Return(&models.User{}, true, nil)
|
||||
suite.Cfg.Security.AllowSignup = true
|
||||
|
||||
suite.Sut.PostSignup(w, r)
|
||||
|
||||
argSignup := suite.UserService.Calls[1].Arguments[0].(*models.Signup)
|
||||
argIsAdmin := suite.UserService.Calls[1].Arguments[1].(bool)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), testUserNewId, argSignup.Username)
|
||||
assert.Equal(suite.T(), testUserNewEmail, argSignup.Email)
|
||||
assert.Equal(suite.T(), testUserNewPassword, argSignup.Password)
|
||||
assert.False(suite.T(), argIsAdmin)
|
||||
assert.Equal(suite.T(), "/", w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostSignup_InvalidForm() {
|
||||
form := url.Values{}
|
||||
form.Add("username", "")
|
||||
form.Add("password", testUserNewPassword)
|
||||
form.Add("password_repeat", testUserNewPassword)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("Count", mock.Anything).Return(1, nil)
|
||||
suite.Cfg.Security.AllowSignup = true
|
||||
|
||||
suite.Sut.PostSignup(w, r)
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusBadRequest, w.Code)
|
||||
assert.Contains(suite.T(), string(body), "User name is invalid")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostSignup_ExistingUser() {
|
||||
form := url.Values{}
|
||||
form.Add("username", testUserExistingId)
|
||||
form.Add("password", testUserExistingPassword)
|
||||
form.Add("password_repeat", testUserExistingPassword)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("Count", mock.Anything).Return(1, nil)
|
||||
suite.UserService.On("CreateOrGet", mock.Anything, mock.Anything).Return(suite.TestUser, false, nil)
|
||||
suite.Cfg.Security.AllowSignup = true
|
||||
|
||||
suite.Sut.PostSignup(w, r)
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusConflict, w.Code)
|
||||
assert.Contains(suite.T(), string(body), "User already existing")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestPostSignup_SignupDisabled() {
|
||||
form := url.Values{}
|
||||
form.Add("username", testUserNewId)
|
||||
form.Add("password", testUserNewPassword)
|
||||
form.Add("password_repeat", testUserNewPassword)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
|
||||
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.UserService.On("Count", mock.Anything).Return(1, nil)
|
||||
|
||||
suite.Sut.PostSignup(w, r)
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusForbidden, w.Code)
|
||||
assert.Contains(suite.T(), string(body), "Registration is disabled on this server")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLogin_Redirect() {
|
||||
r := httptest.NewRequest(http.MethodGet, "/oidc/{provider}/login", nil)
|
||||
r = WithUrlParam(r, "provider", testProvider)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.Sut.GetOidcLogin(w, r)
|
||||
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.True(suite.T(), strings.HasPrefix(w.Header().Get("Location"), suite.OidcMock.AuthorizationEndpoint()))
|
||||
assert.Contains(suite.T(), w.Header().Get("Location"), fmt.Sprintf("state=%s", routeutils.GetOidcState(r)))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLogin_NoMatchingProvider() {
|
||||
r := httptest.NewRequest(http.MethodGet, "/oidc/{provider}/login", nil)
|
||||
r = WithUrlParam(r, "provider", "mock2")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.Sut.GetOidcCallback(w, r)
|
||||
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), "/login", w.Header().Get("Location"))
|
||||
assert.Equal(suite.T(), "oidc provider \"mock2\" not registered", suite.getSessionError(r))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLoginCallback_Success() {
|
||||
url := suite.authorizeUser(suite.OidcUserExisting)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
r = WithUrlParam(r, "provider", testProvider)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routeutils.SetOidcState(testOauthState, r, w)
|
||||
suite.UserService.On("GetUserByOidc", testProvider, suite.OidcUserExisting.Subject).Return(suite.TestUser, nil)
|
||||
suite.UserService.On("Update", mock.Anything).Return(suite.TestUser, nil)
|
||||
|
||||
suite.Sut.GetOidcCallback(w, r)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Empty(suite.T(), suite.getSessionError(r))
|
||||
assert.Equal(suite.T(), "/summary", w.Header().Get("Location"))
|
||||
assert.Contains(suite.T(), w.Header().Get("Set-Cookie"), "wakapi_auth=")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLoginCallback_Success_CreateUser() {
|
||||
suite.Cfg.Security.AllowSignup = true
|
||||
|
||||
url := suite.authorizeUser(suite.OidcUserNew)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
r = WithUrlParam(r, "provider", testProvider)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routeutils.SetOidcState(testOauthState, r, w)
|
||||
suite.UserService.On("GetUserByOidc", testProvider, suite.OidcUserNew.Subject).Return(nil, errors.New(""))
|
||||
suite.UserService.On("GetUserById", suite.OidcUserNew.PreferredUsername).Return(nil, errors.New(""))
|
||||
suite.UserService.On("CreateOrGet", mock.Anything, mock.Anything).Return(suite.TestUser, true, nil)
|
||||
suite.UserService.On("Update", mock.Anything).Return(suite.TestUser, nil)
|
||||
|
||||
suite.Sut.GetOidcCallback(w, r)
|
||||
|
||||
argSignup := suite.UserService.Calls[2].Arguments[0].(*models.Signup)
|
||||
argIsAdmin := suite.UserService.Calls[2].Arguments[1].(bool)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), suite.OidcUserNew.PreferredUsername, argSignup.Username)
|
||||
assert.Equal(suite.T(), suite.OidcUserNew.Email, argSignup.Email)
|
||||
assert.Equal(suite.T(), suite.OidcUserNew.Subject, argSignup.OidcSubject)
|
||||
assert.Equal(suite.T(), testProvider, argSignup.OidcProvider)
|
||||
assert.NotEmpty(suite.T(), argSignup.Password)
|
||||
assert.False(suite.T(), argIsAdmin)
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Empty(suite.T(), suite.getSessionError(r))
|
||||
assert.Equal(suite.T(), "/summary", w.Header().Get("Location"))
|
||||
assert.Contains(suite.T(), w.Header().Get("Set-Cookie"), "wakapi_auth=")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLoginCallback_SignupDisabled() {
|
||||
url := suite.authorizeUser(suite.OidcUserNew)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
r = WithUrlParam(r, "provider", testProvider)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routeutils.SetOidcState(testOauthState, r, w)
|
||||
suite.UserService.On("GetUserByOidc", testProvider, suite.OidcUserNew.Subject).Return(nil, errors.New(""))
|
||||
|
||||
suite.Sut.GetOidcCallback(w, r)
|
||||
|
||||
suite.UserService.AssertExpectations(suite.T())
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), "registration is disabled on this server", suite.getSessionError(r))
|
||||
assert.Equal(suite.T(), "/login", w.Header().Get("Location"))
|
||||
assert.Empty(suite.T(), w.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLoginCallback_InvalidState() {
|
||||
url := suite.authorizeUser(suite.OidcUserNew)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
r = WithUrlParam(r, "provider", testProvider)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routeutils.SetOidcState("some-other-state", r, w)
|
||||
|
||||
suite.Sut.GetOidcCallback(w, r)
|
||||
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), "suspicious operation, got invalid state in oidc callback", suite.getSessionError(r))
|
||||
assert.Equal(suite.T(), "/login", w.Header().Get("Location"))
|
||||
assert.Empty(suite.T(), w.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLoginCallback_AuthExchangeFailure() {
|
||||
url := suite.authorizeUser(suite.OidcUserNew)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
r = WithUrlParam(r, "provider", testProvider)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routeutils.SetOidcState(testOauthState, r, w)
|
||||
|
||||
// token endpoint will be called twice, see https://github.com/golang/oauth2/blob/792c8776358f0c8689d84eef0d0c966937d560fb/internal/token.go#L231-L243
|
||||
suite.OidcMock.QueueError(&mockoidc.ServerError{Code: http.StatusInternalServerError})
|
||||
suite.OidcMock.QueueError(&mockoidc.ServerError{Code: http.StatusInternalServerError})
|
||||
|
||||
suite.Sut.GetOidcCallback(w, r)
|
||||
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), "failed to exchange authorization code for access token", suite.getSessionError(r))
|
||||
assert.Equal(suite.T(), "/login", w.Header().Get("Location"))
|
||||
assert.Empty(suite.T(), w.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLoginCallback_IdTokenExpired() {
|
||||
url := suite.authorizeUser(suite.OidcUserNew)
|
||||
r := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
r = WithUrlParam(r, "provider", testProvider)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routeutils.SetOidcState(testOauthState, r, w)
|
||||
|
||||
suite.OidcMock.AccessTTL = 0 // related: https://github.com/oauth2-proxy/mockoidc/issues/38
|
||||
|
||||
suite.Sut.GetOidcCallback(w, r)
|
||||
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), "failed to verify and decode id_token", suite.getSessionError(r))
|
||||
assert.Equal(suite.T(), "/login", w.Header().Get("Location"))
|
||||
assert.Empty(suite.T(), w.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) TestGetOidcLoginCallback_NoMatchingProvider() {
|
||||
r := httptest.NewRequest(http.MethodGet, "/oidc/{provider}/callback", nil)
|
||||
r = WithUrlParam(r, "provider", "mock2")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.Sut.GetOidcLogin(w, r)
|
||||
|
||||
assert.Equal(suite.T(), http.StatusFound, w.Code)
|
||||
assert.Equal(suite.T(), "/login", w.Header().Get("Location"))
|
||||
assert.Equal(suite.T(), "oidc provider \"mock2\" not registered", suite.getSessionError(r))
|
||||
}
|
||||
|
||||
// Private utility methods
|
||||
func (suite *LoginHandlerTestSuite) setupOidcProvider(name string) {
|
||||
config.WithOidcProvider(suite.Cfg, name, suite.OidcMock.ClientID, suite.OidcMock.ClientSecret, suite.OidcMock.Addr()+"/oidc")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) getSessionError(r *http.Request) string {
|
||||
session, _ := config.GetSessionStore().Get(r, config.CookieKeySession)
|
||||
if errors := session.Flashes("error"); len(errors) > 0 {
|
||||
return errors[0].(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) authorizeUser(user *mockoidc.MockUser) string { // returns the location header's redirect url
|
||||
r := httptest.NewRequest(http.MethodGet, suite.OidcMock.AuthorizationEndpoint(), nil)
|
||||
q := r.URL.Query()
|
||||
q.Set("code", testOauthCode)
|
||||
q.Set("client_id", suite.OidcMock.ClientID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid profile email")
|
||||
q.Set("state", testOauthState)
|
||||
q.Set("redirect_uri", fmt.Sprintf("/oidc/%s/callback", testProvider))
|
||||
r.URL.RawQuery = q.Encode()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
suite.OidcMock.QueueUser(user)
|
||||
suite.OidcMock.QueueCode(testOauthCode)
|
||||
|
||||
suite.OidcMock.Authorize(w, r)
|
||||
return w.Header().Get("Location")
|
||||
}
|
||||
|
||||
func (suite *LoginHandlerTestSuite) resetOidcMockTtl() {
|
||||
suite.OidcMock.AccessTTL = 600 * time.Second
|
||||
suite.OidcMock.RefreshTTL = 60 * time.Minute
|
||||
}
|
||||
|
||||
// TODO: test all remaining endpoints
|
||||
@@ -1,13 +1,15 @@
|
||||
package api
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func withUrlParam(r *http.Request, key, value string) *http.Request {
|
||||
// https://github.com/go-chi/chi/issues/76#issuecomment-370145140
|
||||
func WithUrlParam(r *http.Request, key, value string) *http.Request {
|
||||
r.URL.RawPath = strings.Replace(r.URL.RawPath, "{"+key+"}", value, 1)
|
||||
r.URL.Path = strings.Replace(r.URL.Path, "{"+key+"}", value, 1)
|
||||
rctx := chi.NewRouteContext()
|
||||
@@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/duke-git/lancet/v2/random"
|
||||
@@ -65,6 +66,12 @@ func GetOidcIdTokenPayload(r *http.Request) *conf.IdTokenPayload {
|
||||
return &payload
|
||||
}
|
||||
|
||||
func ClearOidcIdTokenPayload(r *http.Request, w http.ResponseWriter) {
|
||||
session, _ := conf.GetSessionStore().Get(r, conf.CookieKeySession)
|
||||
delete(session.Values, conf.SessionValueOidcIdTokenPayload)
|
||||
session.Save(r, w)
|
||||
}
|
||||
|
||||
func DecodeOidcIdToken(token string, provider *conf.OidcProvider, ctx context.Context) (*conf.IdTokenPayload, error) {
|
||||
idToken, err := provider.Verifier.Verify(ctx, token)
|
||||
if err != nil {
|
||||
@@ -72,8 +79,10 @@ func DecodeOidcIdToken(token string, provider *conf.OidcProvider, ctx context.Co
|
||||
}
|
||||
|
||||
var payload conf.IdTokenPayload
|
||||
if err := idToken.Claims(&payload); err != nil || !payload.IsValid() {
|
||||
if err := idToken.Claims(&payload); err != nil {
|
||||
return nil, err
|
||||
} else if !payload.IsValid() {
|
||||
return nil, errors.New("invalid oidc id token payload")
|
||||
}
|
||||
payload.ProviderName = provider.Name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user