diff options
| author | Raúl Benencia <id@rbenencia.name> | 2026-06-05 15:29:31 -0300 |
|---|---|---|
| committer | Raul Benencia <46945030+raul-te@users.noreply.github.com> | 2026-06-05 16:29:33 -0300 |
| commit | c22c58b9bb67a24531be4e20691f4ed5716db649 (patch) | |
| tree | 1c884e4aa3ea72d2e3a0575d1f0ab8a02c1261d6 /internal | |
| parent | f4631375414422d87f0d16579fd3101fca3c2289 (diff) | |
Use stdlib flags
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/environment/environment.go | 8 | ||||
| -rw-r--r-- | internal/environment/flags.go | 196 | ||||
| -rw-r--r-- | internal/environment/flags_test.go | 133 |
3 files changed, 320 insertions, 17 deletions
diff --git a/internal/environment/environment.go b/internal/environment/environment.go index 0bcb4a7..d0712d5 100644 --- a/internal/environment/environment.go +++ b/internal/environment/environment.go @@ -58,8 +58,12 @@ type Environment struct { // New returns an initialized environment structure func New() *Environment { env := defaultEnvironment() - env.setFlags() - env.validateFlags() + flags, err := env.setFlags(os.Args[1:], os.Environ()) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + env.validateFlags(flags) if env.Debug { env.Logger = log.AllowDebug(env.Logger) diff --git a/internal/environment/flags.go b/internal/environment/flags.go index bbb7d1a..09c243e 100644 --- a/internal/environment/flags.go +++ b/internal/environment/flags.go @@ -15,27 +15,193 @@ package environment import ( + "flag" "fmt" "os" - - "github.com/namsral/flag" + "strconv" + "strings" ) -func (env *Environment) setFlags() { - flag.StringVar(&env.ConfigFile, "config", "", "My config file") - flag.StringVar(&env.BindAddr, "bind-addr", "localhost:8081", "The address where I'm going to listen") - flag.StringVar(&env.BaseURL, "base-url", "", "The base shoelaces URL. If it's not defined, it will default to bind-addr.") - flag.StringVar(&env.DataDir, "data-dir", "", "Directory with mappings, configs, templates, etc.") - flag.StringVar(&env.StaticDir, "static-dir", "web", "A custom web directory with static files") - flag.StringVar(&env.EnvDir, "env-dir", "env_overrides", "Directory with overrides") - flag.StringVar(&env.TemplateExtension, "template-extension", ".slc", "Shoelaces template extension") - flag.StringVar(&env.MappingsFile, "mappings-file", "mappings.yaml", "My mappings YAML file") - flag.BoolVar(&env.Debug, "debug", false, "Debug mode") +func (env *Environment) setFlags(args []string, environ []string) (*flag.FlagSet, error) { + env.setFlagDefaults() + + configFile := configFileFromArgs(args) + if configFile == "" { + configFile = envValue(environ, "CONFIG") + } + if configFile != "" { + env.ConfigFile = configFile + if err := env.loadConfigFile(configFile); err != nil { + return nil, err + } + } + + if err := env.applyEnvVars(environ); err != nil { + return nil, err + } + + flags := env.registerFlags() + if err := flags.Parse(args); err != nil { + return flags, err + } + return flags, nil +} + +func (env *Environment) setFlagDefaults() { + env.BindAddr = "localhost:8081" + env.StaticDir = "web" + env.EnvDir = "env_overrides" + env.TemplateExtension = ".slc" + env.MappingsFile = "mappings.yaml" +} + +func (env *Environment) registerFlags() *flag.FlagSet { + flags := flag.NewFlagSet("shoelaces", flag.ContinueOnError) + flags.StringVar(&env.ConfigFile, "config", env.ConfigFile, "My config file") + flags.StringVar(&env.BindAddr, "bind-addr", env.BindAddr, "The address where I'm going to listen") + flags.StringVar(&env.BaseURL, "base-url", env.BaseURL, "The base shoelaces URL. If it's not defined, it will default to bind-addr.") + flags.StringVar(&env.DataDir, "data-dir", env.DataDir, "Directory with mappings, configs, templates, etc.") + flags.StringVar(&env.StaticDir, "static-dir", env.StaticDir, "A custom web directory with static files") + flags.StringVar(&env.EnvDir, "env-dir", env.EnvDir, "Directory with overrides") + flags.StringVar(&env.TemplateExtension, "template-extension", env.TemplateExtension, "Shoelaces template extension") + flags.StringVar(&env.MappingsFile, "mappings-file", env.MappingsFile, "My mappings YAML file") + flags.BoolVar(&env.Debug, "debug", env.Debug, "Debug mode") + return flags +} + +func configFileFromArgs(args []string) string { + for i, arg := range args { + if arg == "-config" || arg == "--config" { + if i+1 < len(args) { + return args[i+1] + } + return "" + } + if strings.HasPrefix(arg, "-config=") { + return strings.TrimPrefix(arg, "-config=") + } + if strings.HasPrefix(arg, "--config=") { + return strings.TrimPrefix(arg, "--config=") + } + } + return "" +} + +func (env *Environment) loadConfigFile(configFile string) error { + contents, err := os.ReadFile(configFile) + if err != nil { + return err + } - flag.Parse() + for _, line := range strings.Split(string(contents), "\n") { + key, value, ok := parseConfigLine(line) + if !ok { + continue + } + if err := env.setConfigValue(key, value); err != nil { + return err + } + } + return nil +} + +func parseConfigLine(line string) (string, string, bool) { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + return "", "", false + } + + if key, value, found := strings.Cut(line, "="); found { + return strings.TrimSpace(key), strings.TrimSpace(value), true + } + + fields := strings.Fields(line) + if len(fields) == 0 { + return "", "", false + } + if len(fields) == 1 { + return fields[0], "true", true + } + return fields[0], strings.Join(fields[1:], " "), true +} + +func (env *Environment) applyEnvVars(environ []string) error { + if err := env.applyEnvVar(environ, "config", "CONFIG"); err != nil { + return err + } + if err := env.applyEnvVar(environ, "bind-addr", "BIND_ADDR"); err != nil { + return err + } + if err := env.applyEnvVar(environ, "base-url", "BASE_URL"); err != nil { + return err + } + if err := env.applyEnvVar(environ, "data-dir", "DATA_DIR"); err != nil { + return err + } + if err := env.applyEnvVar(environ, "static-dir", "STATIC_DIR"); err != nil { + return err + } + if err := env.applyEnvVar(environ, "env-dir", "ENV_DIR"); err != nil { + return err + } + if err := env.applyEnvVar(environ, "template-extension", "TEMPLATE_EXTENSION"); err != nil { + return err + } + if err := env.applyEnvVar(environ, "mappings-file", "MAPPINGS_FILE"); err != nil { + return err + } + return env.applyEnvVar(environ, "debug", "DEBUG") +} + +func (env *Environment) applyEnvVar(environ []string, key, name string) error { + return env.setConfigValue(key, envValue(environ, name)) +} + +func envValue(environ []string, name string) string { + for _, entry := range environ { + key, value, found := strings.Cut(entry, "=") + if found && key == name { + return value + } + } + return "" +} + +func (env *Environment) setConfigValue(key, value string) error { + if value == "" { + return nil + } + + switch key { + case "config": + env.ConfigFile = value + case "bind-addr": + env.BindAddr = value + case "base-url": + env.BaseURL = value + case "data-dir": + env.DataDir = value + case "static-dir": + env.StaticDir = value + case "env-dir": + env.EnvDir = value + case "template-extension": + env.TemplateExtension = value + case "mappings-file": + env.MappingsFile = value + case "debug": + debug, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid debug value %q: %w", value, err) + } + env.Debug = debug + default: + return fmt.Errorf("unknown config key %q", key) + } + return nil } -func (env *Environment) validateFlags() { +func (env *Environment) validateFlags(flags *flag.FlagSet) { error := false if env.DataDir == "" { @@ -50,7 +216,7 @@ func (env *Environment) validateFlags() { if error { fmt.Println("\nAvailable parameters:") - flag.PrintDefaults() + flags.PrintDefaults() fmt.Println("\nParameters can be specified as environment variables, arguments or in a config file.") os.Exit(1) } diff --git a/internal/environment/flags_test.go b/internal/environment/flags_test.go new file mode 100644 index 0000000..0cab282 --- /dev/null +++ b/internal/environment/flags_test.go @@ -0,0 +1,133 @@ +// Copyright 2018 ThousandEyes Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package environment + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSetFlagsAppliesDefaults(t *testing.T) { + env := defaultEnvironment() + if _, err := env.setFlags(nil, nil); err != nil { + t.Fatal(err) + } + + if env.BindAddr != "localhost:8081" { + t.Errorf("Expected default bind address, got %q", env.BindAddr) + } + if env.StaticDir != "web" { + t.Errorf("Expected default static dir, got %q", env.StaticDir) + } + if env.EnvDir != "env_overrides" { + t.Errorf("Expected default env dir, got %q", env.EnvDir) + } + if env.TemplateExtension != ".slc" { + t.Errorf("Expected default template extension, got %q", env.TemplateExtension) + } + if env.MappingsFile != "mappings.yaml" { + t.Errorf("Expected default mappings file, got %q", env.MappingsFile) + } +} + +func TestSetFlagsLoadsConfigEnvAndCLIInOrder(t *testing.T) { + configFile := writeConfig(t, "shoelaces.conf", ""+ + "# comment\n"+ + "bind-addr=config:1\n"+ + "data-dir config-data\n"+ + "static-dir=config-static\n"+ + "debug\n") + + env := defaultEnvironment() + args := []string{"-bind-addr", "cli:3", "-mappings-file", "cli.yaml"} + environ := []string{ + "CONFIG=" + configFile, + "BIND_ADDR=env:2", + "DEBUG=false", + "STATIC_DIR=env-static", + } + + if _, err := env.setFlags(args, environ); err != nil { + t.Fatal(err) + } + + if env.ConfigFile != configFile { + t.Errorf("Expected config file %q, got %q", configFile, env.ConfigFile) + } + if env.BindAddr != "cli:3" { + t.Errorf("Expected CLI bind address, got %q", env.BindAddr) + } + if env.DataDir != "config-data" { + t.Errorf("Expected data dir from config, got %q", env.DataDir) + } + if env.StaticDir != "env-static" { + t.Errorf("Expected static dir from env, got %q", env.StaticDir) + } + if env.MappingsFile != "cli.yaml" { + t.Errorf("Expected mappings file from CLI, got %q", env.MappingsFile) + } + if env.Debug { + t.Error("Expected DEBUG env var to override bare debug config value") + } +} + +func TestSetFlagsCLIConfigOverridesEnvConfig(t *testing.T) { + cliConfig := writeConfig(t, "cli.conf", "data-dir=cli-data\n") + envConfig := writeConfig(t, "env.conf", "data-dir=env-data\n") + + env := defaultEnvironment() + if _, err := env.setFlags([]string{"-config", cliConfig}, []string{"CONFIG=" + envConfig}); err != nil { + t.Fatal(err) + } + + if env.ConfigFile != cliConfig { + t.Errorf("Expected CLI config file %q, got %q", cliConfig, env.ConfigFile) + } + if env.DataDir != "cli-data" { + t.Errorf("Expected data dir from CLI-selected config, got %q", env.DataDir) + } +} + +func TestSetFlagsReturnsConfigErrors(t *testing.T) { + tests := []struct { + name string + config string + }{ + {name: "unknown key", config: "unknown=value\n"}, + {name: "invalid bool", config: "debug=maybe\n"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + env := defaultEnvironment() + configFile := writeConfig(t, "shoelaces.conf", tt.config) + + if _, err := env.setFlags([]string{"-config", configFile}, nil); err == nil { + t.Fatal("Expected error") + } + }) + } +} + +func writeConfig(t *testing.T, name string, contents string) string { + t.Helper() + + configFile := filepath.Join(t.TempDir(), name) + if err := os.WriteFile(configFile, []byte(contents), 0600); err != nil { + t.Fatal(err) + } + return configFile +} |
