mirror of
https://github.com/muety/wakapi.git
synced 2025-12-05 22:20:24 -08:00
243 lines
6.4 KiB
Go
243 lines
6.4 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/oauth2-proxy/mockoidc"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// TODO: add more tests, e.g. for YAML and environment-variable parsing, validation, etc.
|
|
|
|
func Test_Load_OidcProviders(t *testing.T) {
|
|
oidcMock1, _ := mockoidc.Run()
|
|
defer oidcMock1.Shutdown()
|
|
oidcMock2, _ := mockoidc.Run()
|
|
defer oidcMock2.Shutdown()
|
|
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_NAME", "testprovider1")
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_DISPLAY_NAME", "Test Provider 1")
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_ID", oidcMock1.ClientID)
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_SECRET", oidcMock1.ClientSecret)
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_0_ENDPOINT", oidcMock1.Addr()+"/oidc")
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_1_NAME", "testprovider2")
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_1_CLIENT_ID", oidcMock2.ClientID)
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_1_CLIENT_SECRET", oidcMock2.ClientSecret)
|
|
os.Setenv("WAKAPI_OIDC_PROVIDERS_1_ENDPOINT", oidcMock2.Addr()+"/oidc")
|
|
|
|
cfg := Load("", "")
|
|
oidcCfg := cfg.Security.OidcProviders
|
|
|
|
assert.Len(t, oidcCfg, 2)
|
|
assert.Equal(t, "testprovider1", oidcCfg[0].Name)
|
|
assert.Equal(t, "Test Provider 1", oidcCfg[0].DisplayName)
|
|
assert.Equal(t, "Test Provider 1", oidcCfg[0].String())
|
|
assert.Equal(t, oidcMock1.ClientID, oidcCfg[0].ClientID)
|
|
assert.Equal(t, oidcMock1.ClientSecret, oidcCfg[0].ClientSecret)
|
|
assert.Equal(t, oidcMock1.Addr()+"/oidc", oidcCfg[0].Endpoint)
|
|
assert.Equal(t, "testprovider2", oidcCfg[1].Name)
|
|
assert.Equal(t, "", oidcCfg[1].DisplayName)
|
|
assert.Equal(t, "Testprovider2", oidcCfg[1].String())
|
|
assert.Equal(t, oidcMock2.ClientID, oidcCfg[1].ClientID)
|
|
assert.Equal(t, oidcMock2.ClientSecret, oidcCfg[1].ClientSecret)
|
|
assert.Equal(t, oidcMock2.Addr()+"/oidc", oidcCfg[1].Endpoint)
|
|
|
|
p1, err1 := GetOidcProvider("testprovider1")
|
|
assert.Nil(t, err1)
|
|
assert.Equal(t, "Test Provider 1", p1.DisplayName)
|
|
|
|
p2, err2 := GetOidcProvider("testprovider2")
|
|
assert.Nil(t, err2)
|
|
assert.Equal(t, "Testprovider2", p2.DisplayName)
|
|
}
|
|
|
|
func TestOidcProviderConfig_Validate(t *testing.T) {
|
|
// note: test cases were generated by ai
|
|
testCases := []struct {
|
|
name string
|
|
config oidcProviderConfig
|
|
err string
|
|
}{
|
|
{
|
|
name: "valid",
|
|
config: oidcProviderConfig{
|
|
Name: "test-provider-1",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
Endpoint: "https://provider.com/oidc",
|
|
},
|
|
err: "",
|
|
},
|
|
{
|
|
name: "valid with http",
|
|
config: oidcProviderConfig{
|
|
Name: "test-provider-1",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
Endpoint: "http://provider.com/oidc",
|
|
},
|
|
err: "",
|
|
},
|
|
{
|
|
name: "invalid name with spaces",
|
|
config: oidcProviderConfig{
|
|
Name: "test provider",
|
|
},
|
|
err: "invalid provider name 'test provider', must only contain alphanumeric characters or '-'",
|
|
},
|
|
{
|
|
name: "invalid name with underscore",
|
|
config: oidcProviderConfig{
|
|
Name: "test_provider",
|
|
},
|
|
err: "invalid provider name 'test_provider', must only contain alphanumeric characters or '-'",
|
|
},
|
|
{
|
|
name: "missing client id",
|
|
config: oidcProviderConfig{
|
|
Name: "test-provider",
|
|
ClientSecret: "client-secret",
|
|
Endpoint: "https://provider.com/oidc",
|
|
},
|
|
err: "provider 'test-provider' is missing client id",
|
|
},
|
|
{
|
|
name: "missing client secret",
|
|
config: oidcProviderConfig{
|
|
Name: "test-provider",
|
|
ClientID: "client-id",
|
|
Endpoint: "https://provider.com/oidc",
|
|
},
|
|
err: "provider 'test-provider' is missing client secret",
|
|
},
|
|
{
|
|
name: "missing endpoint",
|
|
config: oidcProviderConfig{
|
|
Name: "test-provider",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
},
|
|
err: "provider 'test-provider' is missing endpoint",
|
|
},
|
|
{
|
|
name: "invalid endpoint scheme",
|
|
config: oidcProviderConfig{
|
|
Name: "test-provider",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
Endpoint: "ftp://provider.com/oidc",
|
|
},
|
|
err: "provider 'test-provider' is missing endpoint",
|
|
},
|
|
{
|
|
name: "endpoint without scheme",
|
|
config: oidcProviderConfig{
|
|
Name: "test-provider",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
Endpoint: "provider.com/oidc",
|
|
},
|
|
err: "provider 'test-provider' is missing endpoint",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := tc.config.Validate()
|
|
if tc.err == "" {
|
|
assert.NoError(t, err)
|
|
} else {
|
|
assert.EqualError(t, err, tc.err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConfig_IsDev(t *testing.T) {
|
|
assert.True(t, IsDev("dev"))
|
|
assert.True(t, IsDev("development"))
|
|
assert.False(t, IsDev("prod"))
|
|
assert.False(t, IsDev("production"))
|
|
assert.False(t, IsDev("anything else"))
|
|
}
|
|
|
|
func Test_mysqlConnectionString(t *testing.T) {
|
|
c := &dbConfig{
|
|
Host: "test_host",
|
|
Port: 9999,
|
|
User: "test_user",
|
|
Password: "test_password",
|
|
Name: "test_name",
|
|
Dialect: "mysql",
|
|
Charset: "utf8mb4",
|
|
MaxConn: 10,
|
|
Compress: true,
|
|
}
|
|
|
|
assert.Equal(t, fmt.Sprintf(
|
|
"%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=%s&compress=true&sql_mode=ANSI_QUOTES",
|
|
c.User,
|
|
c.Password,
|
|
c.Host,
|
|
c.Port,
|
|
c.Name,
|
|
"Local",
|
|
), mysqlConnectionString(c))
|
|
}
|
|
|
|
func Test_mysqlConnectionStringSocket(t *testing.T) {
|
|
c := &dbConfig{
|
|
Socket: "/var/run/mysql.sock",
|
|
Port: 9999,
|
|
User: "test_user",
|
|
Password: "test_password",
|
|
Name: "test_name",
|
|
Dialect: "mysql",
|
|
Charset: "utf8mb4",
|
|
MaxConn: 10,
|
|
Compress: true,
|
|
}
|
|
|
|
assert.Equal(t, fmt.Sprintf(
|
|
"%s:%s@unix(%s)/%s?charset=utf8mb4&parseTime=true&loc=%s&compress=true&sql_mode=ANSI_QUOTES",
|
|
c.User,
|
|
c.Password,
|
|
c.Socket,
|
|
c.Name,
|
|
"Local",
|
|
), mysqlConnectionString(c))
|
|
}
|
|
|
|
func Test_postgresConnectionString(t *testing.T) {
|
|
c := &dbConfig{
|
|
Host: "test_host",
|
|
Port: 9999,
|
|
User: "test_user",
|
|
Password: "test_password",
|
|
Name: "test_name",
|
|
Dialect: "postgres",
|
|
MaxConn: 10,
|
|
}
|
|
|
|
assert.Equal(t, fmt.Sprintf(
|
|
"host=%s port=%d user=%s dbname=%s password=%s sslmode=disable",
|
|
c.Host,
|
|
c.Port,
|
|
c.User,
|
|
c.Name,
|
|
c.Password,
|
|
), postgresConnectionString(c))
|
|
}
|
|
|
|
func Test_sqliteConnectionString(t *testing.T) {
|
|
c := &dbConfig{
|
|
Name: "test_name",
|
|
Dialect: "sqlite3",
|
|
}
|
|
assert.True(t, strings.HasPrefix(sqliteConnectionString(c), c.Name))
|
|
assert.Contains(t, strings.ToLower(sqliteConnectionString(c)), "journal_mode=wal")
|
|
}
|