Files
wakapi/config/config_test.go

243 lines
6.4 KiB
Go

package config
import (
"fmt"
"os"
"strings"
"testing"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/assert"
)
// TODO: add more tests, e.g. for YAML and environment-variable parsing, further validation, etc.
// Test_Load_OidcProviders verifies that OIDC providers configured through
// WAKAPI_OIDC_PROVIDERS_<i>_* environment variables are picked up by Load,
// that String() yields the display name when one is set and "Testprovider2"
// for the provider without one, and that GetOidcProvider resolves both
// providers by name afterwards.
func Test_Load_OidcProviders(t *testing.T) {
	oidcMock1, err := mockoidc.Run()
	if err != nil {
		t.Fatalf("starting first mock oidc server: %v", err)
	}
	defer oidcMock1.Shutdown()
	oidcMock2, err := mockoidc.Run()
	if err != nil {
		t.Fatalf("starting second mock oidc server: %v", err)
	}
	defer oidcMock2.Shutdown()

	// setEnv sets an environment variable for the duration of this test and
	// restores its previous state on cleanup, so the configured providers do
	// not leak into other tests of this package that call Load.
	setEnv := func(key, value string) {
		t.Helper()
		prev, hadPrev := os.LookupEnv(key)
		if err := os.Setenv(key, value); err != nil {
			t.Fatalf("setting %s: %v", key, err)
		}
		t.Cleanup(func() {
			if hadPrev {
				_ = os.Setenv(key, prev)
			} else {
				_ = os.Unsetenv(key)
			}
		})
	}

	setEnv("WAKAPI_OIDC_PROVIDERS_0_NAME", "testprovider1")
	setEnv("WAKAPI_OIDC_PROVIDERS_0_DISPLAY_NAME", "Test Provider 1")
	setEnv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_ID", oidcMock1.ClientID)
	setEnv("WAKAPI_OIDC_PROVIDERS_0_CLIENT_SECRET", oidcMock1.ClientSecret)
	setEnv("WAKAPI_OIDC_PROVIDERS_0_ENDPOINT", oidcMock1.Addr()+"/oidc")
	setEnv("WAKAPI_OIDC_PROVIDERS_1_NAME", "testprovider2")
	setEnv("WAKAPI_OIDC_PROVIDERS_1_CLIENT_ID", oidcMock2.ClientID)
	setEnv("WAKAPI_OIDC_PROVIDERS_1_CLIENT_SECRET", oidcMock2.ClientSecret)
	setEnv("WAKAPI_OIDC_PROVIDERS_1_ENDPOINT", oidcMock2.Addr()+"/oidc")

	cfg := Load("", "")
	oidcCfg := cfg.Security.OidcProviders

	// Bail out on a wrong count; indexing below would otherwise panic and
	// hide the actual assertion failure.
	if !assert.Len(t, oidcCfg, 2) {
		return
	}

	assert.Equal(t, "testprovider1", oidcCfg[0].Name)
	assert.Equal(t, "Test Provider 1", oidcCfg[0].DisplayName)
	assert.Equal(t, "Test Provider 1", oidcCfg[0].String())
	assert.Equal(t, oidcMock1.ClientID, oidcCfg[0].ClientID)
	assert.Equal(t, oidcMock1.ClientSecret, oidcCfg[0].ClientSecret)
	assert.Equal(t, oidcMock1.Addr()+"/oidc", oidcCfg[0].Endpoint)

	assert.Equal(t, "testprovider2", oidcCfg[1].Name)
	assert.Equal(t, "", oidcCfg[1].DisplayName)
	assert.Equal(t, "Testprovider2", oidcCfg[1].String())
	assert.Equal(t, oidcMock2.ClientID, oidcCfg[1].ClientID)
	assert.Equal(t, oidcMock2.ClientSecret, oidcCfg[1].ClientSecret)
	assert.Equal(t, oidcMock2.Addr()+"/oidc", oidcCfg[1].Endpoint)

	// Only inspect the looked-up providers when the lookup succeeded, so a
	// failed lookup is reported as such instead of as a nil dereference.
	p1, err := GetOidcProvider("testprovider1")
	if assert.NoError(t, err) {
		assert.Equal(t, "Test Provider 1", p1.DisplayName)
	}
	p2, err := GetOidcProvider("testprovider2")
	if assert.NoError(t, err) {
		assert.Equal(t, "Testprovider2", p2.DisplayName)
	}
}
// TestOidcProviderConfig_Validate runs oidcProviderConfig.Validate against a
// table of configurations. For each case it expects no error when wantErr is
// empty, and otherwise an error whose message equals wantErr exactly.
func TestOidcProviderConfig_Validate(t *testing.T) {
	// Note: these cases were originally generated by AI.
	cases := []struct {
		desc    string
		input   oidcProviderConfig
		wantErr string
	}{
		{
			desc: "valid",
			input: oidcProviderConfig{
				Name:         "test-provider-1",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "https://provider.com/oidc",
			},
		},
		{
			desc: "valid with http",
			input: oidcProviderConfig{
				Name:         "test-provider-1",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "http://provider.com/oidc",
			},
		},
		{
			desc:    "invalid name with spaces",
			input:   oidcProviderConfig{Name: "test provider"},
			wantErr: "invalid provider name 'test provider', must only contain alphanumeric characters or '-'",
		},
		{
			desc:    "invalid name with underscore",
			input:   oidcProviderConfig{Name: "test_provider"},
			wantErr: "invalid provider name 'test_provider', must only contain alphanumeric characters or '-'",
		},
		{
			desc: "missing client id",
			input: oidcProviderConfig{
				Name:         "test-provider",
				ClientSecret: "client-secret",
				Endpoint:     "https://provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing client id",
		},
		{
			desc: "missing client secret",
			input: oidcProviderConfig{
				Name:     "test-provider",
				ClientID: "client-id",
				Endpoint: "https://provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing client secret",
		},
		{
			desc: "missing endpoint",
			input: oidcProviderConfig{
				Name:         "test-provider",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
			},
			wantErr: "provider 'test-provider' is missing endpoint",
		},
		{
			desc: "invalid endpoint scheme",
			input: oidcProviderConfig{
				Name:         "test-provider",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "ftp://provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing endpoint",
		},
		{
			desc: "endpoint without scheme",
			input: oidcProviderConfig{
				Name:         "test-provider",
				ClientID:     "client-id",
				ClientSecret: "client-secret",
				Endpoint:     "provider.com/oidc",
			},
			wantErr: "provider 'test-provider' is missing endpoint",
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			got := c.input.Validate()
			if c.wantErr != "" {
				assert.EqualError(t, got, c.wantErr)
				return
			}
			assert.NoError(t, got)
		})
	}
}
// TestConfig_IsDev checks that IsDev accepts "dev" and "development" and
// rejects "prod", "production" and an arbitrary other string.
func TestConfig_IsDev(t *testing.T) {
	for _, tc := range []struct {
		env  string
		want bool
	}{
		{env: "dev", want: true},
		{env: "development", want: true},
		{env: "prod", want: false},
		{env: "production", want: false},
		{env: "anything else", want: false},
	} {
		if tc.want {
			assert.True(t, IsDev(tc.env))
		} else {
			assert.False(t, IsDev(tc.env))
		}
	}
}
// Test_mysqlConnectionString checks the TCP-based MySQL DSN produced for a
// dbConfig that specifies host and port (no socket).
func Test_mysqlConnectionString(t *testing.T) {
	cfg := &dbConfig{
		Host:     "test_host",
		Port:     9999,
		User:     "test_user",
		Password: "test_password",
		Name:     "test_name",
		Dialect:  "mysql",
		Charset:  "utf8mb4",
		MaxConn:  10,
		Compress: true,
	}

	want := fmt.Sprintf(
		"%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=%s&compress=true&sql_mode=ANSI_QUOTES",
		cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, "Local",
	)
	got := mysqlConnectionString(cfg)

	assert.Equal(t, want, got)
}
// Test_mysqlConnectionStringSocket checks that a dbConfig with a Socket set
// yields a unix-socket MySQL DSN (the port is present in the config but
// does not appear in the expected string).
func Test_mysqlConnectionStringSocket(t *testing.T) {
	cfg := &dbConfig{
		Socket:   "/var/run/mysql.sock",
		Port:     9999,
		User:     "test_user",
		Password: "test_password",
		Name:     "test_name",
		Dialect:  "mysql",
		Charset:  "utf8mb4",
		MaxConn:  10,
		Compress: true,
	}

	want := fmt.Sprintf(
		"%s:%s@unix(%s)/%s?charset=utf8mb4&parseTime=true&loc=%s&compress=true&sql_mode=ANSI_QUOTES",
		cfg.User, cfg.Password, cfg.Socket, cfg.Name, "Local",
	)
	got := mysqlConnectionString(cfg)

	assert.Equal(t, want, got)
}
// Test_postgresConnectionString checks the key/value-style Postgres DSN
// built from host, port, user, database name and password, with SSL
// disabled.
func Test_postgresConnectionString(t *testing.T) {
	cfg := &dbConfig{
		Host:     "test_host",
		Port:     9999,
		User:     "test_user",
		Password: "test_password",
		Name:     "test_name",
		Dialect:  "postgres",
		MaxConn:  10,
	}

	want := fmt.Sprintf(
		"host=%s port=%d user=%s dbname=%s password=%s sslmode=disable",
		cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.Password,
	)
	got := postgresConnectionString(cfg)

	assert.Equal(t, want, got)
}
// Test_sqliteConnectionString checks that the SQLite DSN starts with the
// configured database name and enables WAL journaling (matched
// case-insensitively).
func Test_sqliteConnectionString(t *testing.T) {
	cfg := &dbConfig{Name: "test_name", Dialect: "sqlite3"}

	dsn := sqliteConnectionString(cfg)

	assert.True(t, strings.HasPrefix(dsn, cfg.Name))
	assert.Contains(t, strings.ToLower(dsn), "journal_mode=wal")
}