aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/environment/environment.go8
-rw-r--r--internal/environment/flags.go196
-rw-r--r--internal/environment/flags_test.go133
3 files changed, 320 insertions, 17 deletions
diff --git a/internal/environment/environment.go b/internal/environment/environment.go
index 0bcb4a7..d0712d5 100644
--- a/internal/environment/environment.go
+++ b/internal/environment/environment.go
@@ -58,8 +58,12 @@ type Environment struct {
// New returns an initialized environment structure
func New() *Environment {
env := defaultEnvironment()
- env.setFlags()
- env.validateFlags()
+ flags, err := env.setFlags(os.Args[1:], os.Environ())
+ if err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+ env.validateFlags(flags)
if env.Debug {
env.Logger = log.AllowDebug(env.Logger)
diff --git a/internal/environment/flags.go b/internal/environment/flags.go
index bbb7d1a..09c243e 100644
--- a/internal/environment/flags.go
+++ b/internal/environment/flags.go
@@ -15,27 +15,193 @@
package environment
import (
+ "flag"
"fmt"
"os"
-
- "github.com/namsral/flag"
+ "strconv"
+ "strings"
)
-func (env *Environment) setFlags() {
- flag.StringVar(&env.ConfigFile, "config", "", "My config file")
- flag.StringVar(&env.BindAddr, "bind-addr", "localhost:8081", "The address where I'm going to listen")
- flag.StringVar(&env.BaseURL, "base-url", "", "The base shoelaces URL. If it's not defined, it will default to bind-addr.")
- flag.StringVar(&env.DataDir, "data-dir", "", "Directory with mappings, configs, templates, etc.")
- flag.StringVar(&env.StaticDir, "static-dir", "web", "A custom web directory with static files")
- flag.StringVar(&env.EnvDir, "env-dir", "env_overrides", "Directory with overrides")
- flag.StringVar(&env.TemplateExtension, "template-extension", ".slc", "Shoelaces template extension")
- flag.StringVar(&env.MappingsFile, "mappings-file", "mappings.yaml", "My mappings YAML file")
- flag.BoolVar(&env.Debug, "debug", false, "Debug mode")
+func (env *Environment) setFlags(args []string, environ []string) (*flag.FlagSet, error) {
+ env.setFlagDefaults()
+
+ configFile := configFileFromArgs(args)
+ if configFile == "" {
+ configFile = envValue(environ, "CONFIG")
+ }
+ if configFile != "" {
+ env.ConfigFile = configFile
+ if err := env.loadConfigFile(configFile); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := env.applyEnvVars(environ); err != nil {
+ return nil, err
+ }
+
+ flags := env.registerFlags()
+ if err := flags.Parse(args); err != nil {
+ return flags, err
+ }
+ return flags, nil
+}
+
+func (env *Environment) setFlagDefaults() {
+ env.BindAddr = "localhost:8081"
+ env.StaticDir = "web"
+ env.EnvDir = "env_overrides"
+ env.TemplateExtension = ".slc"
+ env.MappingsFile = "mappings.yaml"
+}
+
+func (env *Environment) registerFlags() *flag.FlagSet {
+ flags := flag.NewFlagSet("shoelaces", flag.ContinueOnError)
+ flags.StringVar(&env.ConfigFile, "config", env.ConfigFile, "My config file")
+ flags.StringVar(&env.BindAddr, "bind-addr", env.BindAddr, "The address where I'm going to listen")
+ flags.StringVar(&env.BaseURL, "base-url", env.BaseURL, "The base shoelaces URL. If it's not defined, it will default to bind-addr.")
+ flags.StringVar(&env.DataDir, "data-dir", env.DataDir, "Directory with mappings, configs, templates, etc.")
+ flags.StringVar(&env.StaticDir, "static-dir", env.StaticDir, "A custom web directory with static files")
+ flags.StringVar(&env.EnvDir, "env-dir", env.EnvDir, "Directory with overrides")
+ flags.StringVar(&env.TemplateExtension, "template-extension", env.TemplateExtension, "Shoelaces template extension")
+ flags.StringVar(&env.MappingsFile, "mappings-file", env.MappingsFile, "My mappings YAML file")
+ flags.BoolVar(&env.Debug, "debug", env.Debug, "Debug mode")
+ return flags
+}
+
+func configFileFromArgs(args []string) string {
+ for i, arg := range args {
+ if arg == "-config" || arg == "--config" {
+ if i+1 < len(args) {
+ return args[i+1]
+ }
+ return ""
+ }
+ if strings.HasPrefix(arg, "-config=") {
+ return strings.TrimPrefix(arg, "-config=")
+ }
+ if strings.HasPrefix(arg, "--config=") {
+ return strings.TrimPrefix(arg, "--config=")
+ }
+ }
+ return ""
+}
+
+func (env *Environment) loadConfigFile(configFile string) error {
+ contents, err := os.ReadFile(configFile)
+ if err != nil {
+ return err
+ }
- flag.Parse()
+ for _, line := range strings.Split(string(contents), "\n") {
+ key, value, ok := parseConfigLine(line)
+ if !ok {
+ continue
+ }
+ if err := env.setConfigValue(key, value); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func parseConfigLine(line string) (string, string, bool) {
+ line = strings.TrimSpace(line)
+ if line == "" || strings.HasPrefix(line, "#") {
+ return "", "", false
+ }
+
+ if key, value, found := strings.Cut(line, "="); found {
+ return strings.TrimSpace(key), strings.TrimSpace(value), true
+ }
+
+ fields := strings.Fields(line)
+ if len(fields) == 0 {
+ return "", "", false
+ }
+ if len(fields) == 1 {
+ return fields[0], "true", true
+ }
+ return fields[0], strings.Join(fields[1:], " "), true
+}
+
+func (env *Environment) applyEnvVars(environ []string) error {
+ if err := env.applyEnvVar(environ, "config", "CONFIG"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "bind-addr", "BIND_ADDR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "base-url", "BASE_URL"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "data-dir", "DATA_DIR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "static-dir", "STATIC_DIR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "env-dir", "ENV_DIR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "template-extension", "TEMPLATE_EXTENSION"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "mappings-file", "MAPPINGS_FILE"); err != nil {
+ return err
+ }
+ return env.applyEnvVar(environ, "debug", "DEBUG")
+}
+
+func (env *Environment) applyEnvVar(environ []string, key, name string) error {
+ return env.setConfigValue(key, envValue(environ, name))
+}
+
+func envValue(environ []string, name string) string {
+ for _, entry := range environ {
+ key, value, found := strings.Cut(entry, "=")
+ if found && key == name {
+ return value
+ }
+ }
+ return ""
+}
+
+func (env *Environment) setConfigValue(key, value string) error {
+ if value == "" {
+ return nil
+ }
+
+ switch key {
+ case "config":
+ env.ConfigFile = value
+ case "bind-addr":
+ env.BindAddr = value
+ case "base-url":
+ env.BaseURL = value
+ case "data-dir":
+ env.DataDir = value
+ case "static-dir":
+ env.StaticDir = value
+ case "env-dir":
+ env.EnvDir = value
+ case "template-extension":
+ env.TemplateExtension = value
+ case "mappings-file":
+ env.MappingsFile = value
+ case "debug":
+ debug, err := strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("invalid debug value %q: %w", value, err)
+ }
+ env.Debug = debug
+ default:
+ return fmt.Errorf("unknown config key %q", key)
+ }
+ return nil
}
-func (env *Environment) validateFlags() {
+func (env *Environment) validateFlags(flags *flag.FlagSet) {
error := false
if env.DataDir == "" {
@@ -50,7 +216,7 @@ func (env *Environment) validateFlags() {
if error {
fmt.Println("\nAvailable parameters:")
- flag.PrintDefaults()
+ flags.PrintDefaults()
fmt.Println("\nParameters can be specified as environment variables, arguments or in a config file.")
os.Exit(1)
}
diff --git a/internal/environment/flags_test.go b/internal/environment/flags_test.go
new file mode 100644
index 0000000..0cab282
--- /dev/null
+++ b/internal/environment/flags_test.go
@@ -0,0 +1,133 @@
+// Copyright 2018 ThousandEyes Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package environment
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestSetFlagsAppliesDefaults(t *testing.T) {
+ env := defaultEnvironment()
+ if _, err := env.setFlags(nil, nil); err != nil {
+ t.Fatal(err)
+ }
+
+ if env.BindAddr != "localhost:8081" {
+ t.Errorf("Expected default bind address, got %q", env.BindAddr)
+ }
+ if env.StaticDir != "web" {
+ t.Errorf("Expected default static dir, got %q", env.StaticDir)
+ }
+ if env.EnvDir != "env_overrides" {
+ t.Errorf("Expected default env dir, got %q", env.EnvDir)
+ }
+ if env.TemplateExtension != ".slc" {
+ t.Errorf("Expected default template extension, got %q", env.TemplateExtension)
+ }
+ if env.MappingsFile != "mappings.yaml" {
+ t.Errorf("Expected default mappings file, got %q", env.MappingsFile)
+ }
+}
+
+func TestSetFlagsLoadsConfigEnvAndCLIInOrder(t *testing.T) {
+ configFile := writeConfig(t, "shoelaces.conf", ""+
+ "# comment\n"+
+ "bind-addr=config:1\n"+
+ "data-dir config-data\n"+
+ "static-dir=config-static\n"+
+ "debug\n")
+
+ env := defaultEnvironment()
+ args := []string{"-bind-addr", "cli:3", "-mappings-file", "cli.yaml"}
+ environ := []string{
+ "CONFIG=" + configFile,
+ "BIND_ADDR=env:2",
+ "DEBUG=false",
+ "STATIC_DIR=env-static",
+ }
+
+ if _, err := env.setFlags(args, environ); err != nil {
+ t.Fatal(err)
+ }
+
+ if env.ConfigFile != configFile {
+ t.Errorf("Expected config file %q, got %q", configFile, env.ConfigFile)
+ }
+ if env.BindAddr != "cli:3" {
+ t.Errorf("Expected CLI bind address, got %q", env.BindAddr)
+ }
+ if env.DataDir != "config-data" {
+ t.Errorf("Expected data dir from config, got %q", env.DataDir)
+ }
+ if env.StaticDir != "env-static" {
+ t.Errorf("Expected static dir from env, got %q", env.StaticDir)
+ }
+ if env.MappingsFile != "cli.yaml" {
+ t.Errorf("Expected mappings file from CLI, got %q", env.MappingsFile)
+ }
+ if env.Debug {
+ t.Error("Expected DEBUG env var to override bare debug config value")
+ }
+}
+
+func TestSetFlagsCLIConfigOverridesEnvConfig(t *testing.T) {
+ cliConfig := writeConfig(t, "cli.conf", "data-dir=cli-data\n")
+ envConfig := writeConfig(t, "env.conf", "data-dir=env-data\n")
+
+ env := defaultEnvironment()
+ if _, err := env.setFlags([]string{"-config", cliConfig}, []string{"CONFIG=" + envConfig}); err != nil {
+ t.Fatal(err)
+ }
+
+ if env.ConfigFile != cliConfig {
+ t.Errorf("Expected CLI config file %q, got %q", cliConfig, env.ConfigFile)
+ }
+ if env.DataDir != "cli-data" {
+ t.Errorf("Expected data dir from CLI-selected config, got %q", env.DataDir)
+ }
+}
+
+func TestSetFlagsReturnsConfigErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ config string
+ }{
+ {name: "unknown key", config: "unknown=value\n"},
+ {name: "invalid bool", config: "debug=maybe\n"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ env := defaultEnvironment()
+ configFile := writeConfig(t, "shoelaces.conf", tt.config)
+
+ if _, err := env.setFlags([]string{"-config", configFile}, nil); err == nil {
+ t.Fatal("Expected error")
+ }
+ })
+ }
+}
+
+func writeConfig(t *testing.T, name string, contents string) string {
+ t.Helper()
+
+ configFile := filepath.Join(t.TempDir(), name)
+ if err := os.WriteFile(configFile, []byte(contents), 0600); err != nil {
+ t.Fatal(err)
+ }
+ return configFile
+}
nihil fit ex nihilo