aboutsummaryrefslogtreecommitdiff
path: root/internal/environment/flags.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/environment/flags.go')
-rw-r--r--internal/environment/flags.go196
1 files changed, 181 insertions, 15 deletions
diff --git a/internal/environment/flags.go b/internal/environment/flags.go
index bbb7d1a..09c243e 100644
--- a/internal/environment/flags.go
+++ b/internal/environment/flags.go
@@ -15,27 +15,193 @@
package environment
import (
+ "flag"
"fmt"
"os"
-
- "github.com/namsral/flag"
+ "strconv"
+ "strings"
)
-func (env *Environment) setFlags() {
- flag.StringVar(&env.ConfigFile, "config", "", "My config file")
- flag.StringVar(&env.BindAddr, "bind-addr", "localhost:8081", "The address where I'm going to listen")
- flag.StringVar(&env.BaseURL, "base-url", "", "The base shoelaces URL. If it's not defined, it will default to bind-addr.")
- flag.StringVar(&env.DataDir, "data-dir", "", "Directory with mappings, configs, templates, etc.")
- flag.StringVar(&env.StaticDir, "static-dir", "web", "A custom web directory with static files")
- flag.StringVar(&env.EnvDir, "env-dir", "env_overrides", "Directory with overrides")
- flag.StringVar(&env.TemplateExtension, "template-extension", ".slc", "Shoelaces template extension")
- flag.StringVar(&env.MappingsFile, "mappings-file", "mappings.yaml", "My mappings YAML file")
- flag.BoolVar(&env.Debug, "debug", false, "Debug mode")
+func (env *Environment) setFlags(args []string, environ []string) (*flag.FlagSet, error) {
+ env.setFlagDefaults()
+
+ configFile := configFileFromArgs(args)
+ if configFile == "" {
+ configFile = envValue(environ, "CONFIG")
+ }
+ if configFile != "" {
+ env.ConfigFile = configFile
+ if err := env.loadConfigFile(configFile); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := env.applyEnvVars(environ); err != nil {
+ return nil, err
+ }
+
+ flags := env.registerFlags()
+ if err := flags.Parse(args); err != nil {
+ return flags, err
+ }
+ return flags, nil
+}
+
+func (env *Environment) setFlagDefaults() {
+ env.BindAddr = "localhost:8081"
+ env.StaticDir = "web"
+ env.EnvDir = "env_overrides"
+ env.TemplateExtension = ".slc"
+ env.MappingsFile = "mappings.yaml"
+}
+
+func (env *Environment) registerFlags() *flag.FlagSet {
+ flags := flag.NewFlagSet("shoelaces", flag.ContinueOnError)
+ flags.StringVar(&env.ConfigFile, "config", env.ConfigFile, "My config file")
+ flags.StringVar(&env.BindAddr, "bind-addr", env.BindAddr, "The address where I'm going to listen")
+ flags.StringVar(&env.BaseURL, "base-url", env.BaseURL, "The base shoelaces URL. If it's not defined, it will default to bind-addr.")
+ flags.StringVar(&env.DataDir, "data-dir", env.DataDir, "Directory with mappings, configs, templates, etc.")
+ flags.StringVar(&env.StaticDir, "static-dir", env.StaticDir, "A custom web directory with static files")
+ flags.StringVar(&env.EnvDir, "env-dir", env.EnvDir, "Directory with overrides")
+ flags.StringVar(&env.TemplateExtension, "template-extension", env.TemplateExtension, "Shoelaces template extension")
+ flags.StringVar(&env.MappingsFile, "mappings-file", env.MappingsFile, "My mappings YAML file")
+ flags.BoolVar(&env.Debug, "debug", env.Debug, "Debug mode")
+ return flags
+}
+
+func configFileFromArgs(args []string) string {
+ for i, arg := range args {
+ if arg == "-config" || arg == "--config" {
+ if i+1 < len(args) {
+ return args[i+1]
+ }
+ return ""
+ }
+ if strings.HasPrefix(arg, "-config=") {
+ return strings.TrimPrefix(arg, "-config=")
+ }
+ if strings.HasPrefix(arg, "--config=") {
+ return strings.TrimPrefix(arg, "--config=")
+ }
+ }
+ return ""
+}
+
+func (env *Environment) loadConfigFile(configFile string) error {
+ contents, err := os.ReadFile(configFile)
+ if err != nil {
+ return err
+ }
- flag.Parse()
+ for _, line := range strings.Split(string(contents), "\n") {
+ key, value, ok := parseConfigLine(line)
+ if !ok {
+ continue
+ }
+ if err := env.setConfigValue(key, value); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func parseConfigLine(line string) (string, string, bool) {
+ line = strings.TrimSpace(line)
+ if line == "" || strings.HasPrefix(line, "#") {
+ return "", "", false
+ }
+
+ if key, value, found := strings.Cut(line, "="); found {
+ return strings.TrimSpace(key), strings.TrimSpace(value), true
+ }
+
+ fields := strings.Fields(line)
+ if len(fields) == 0 {
+ return "", "", false
+ }
+ if len(fields) == 1 {
+ return fields[0], "true", true
+ }
+ return fields[0], strings.Join(fields[1:], " "), true
+}
+
+func (env *Environment) applyEnvVars(environ []string) error {
+ if err := env.applyEnvVar(environ, "config", "CONFIG"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "bind-addr", "BIND_ADDR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "base-url", "BASE_URL"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "data-dir", "DATA_DIR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "static-dir", "STATIC_DIR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "env-dir", "ENV_DIR"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "template-extension", "TEMPLATE_EXTENSION"); err != nil {
+ return err
+ }
+ if err := env.applyEnvVar(environ, "mappings-file", "MAPPINGS_FILE"); err != nil {
+ return err
+ }
+ return env.applyEnvVar(environ, "debug", "DEBUG")
+}
+
+func (env *Environment) applyEnvVar(environ []string, key, name string) error {
+ return env.setConfigValue(key, envValue(environ, name))
+}
+
+func envValue(environ []string, name string) string {
+ for _, entry := range environ {
+ key, value, found := strings.Cut(entry, "=")
+ if found && key == name {
+ return value
+ }
+ }
+ return ""
+}
+
+func (env *Environment) setConfigValue(key, value string) error {
+ if value == "" {
+ return nil
+ }
+
+ switch key {
+ case "config":
+ env.ConfigFile = value
+ case "bind-addr":
+ env.BindAddr = value
+ case "base-url":
+ env.BaseURL = value
+ case "data-dir":
+ env.DataDir = value
+ case "static-dir":
+ env.StaticDir = value
+ case "env-dir":
+ env.EnvDir = value
+ case "template-extension":
+ env.TemplateExtension = value
+ case "mappings-file":
+ env.MappingsFile = value
+ case "debug":
+ debug, err := strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("invalid debug value %q: %w", value, err)
+ }
+ env.Debug = debug
+ default:
+ return fmt.Errorf("unknown config key %q", key)
+ }
+ return nil
}
-func (env *Environment) validateFlags() {
+func (env *Environment) validateFlags(flags *flag.FlagSet) {
error := false
if env.DataDir == "" {
@@ -50,7 +216,7 @@ func (env *Environment) validateFlags() {
if error {
fmt.Println("\nAvailable parameters:")
- flag.PrintDefaults()
+ flags.PrintDefaults()
fmt.Println("\nParameters can be specified as environment variables, arguments or in a config file.")
os.Exit(1)
}
nihil fit ex nihilo