3 Commits

2 changed files with 16 additions and 15 deletions

25
ca.go
View File

@@ -291,14 +291,9 @@ func parseValidity(validity string) (time.Duration, error) {
} }
func SavePEM(filename string, data []byte, secure bool) error { func SavePEM(filename string, data []byte, secure bool) error {
if !overwrite { if _, err := os.Stat(filename); err == nil {
if _, err := os.Stat(filename); err == nil { return fmt.Errorf("file %s already exists", filename)
return fmt.Errorf("file %s already exists (overwrite not allowed)", filename)
} else if !os.IsNotExist(err) {
return fmt.Errorf("could not check file %s: %v", filename, err)
}
} }
if secure { if secure {
return os.WriteFile(filename, data, 0600) return os.WriteFile(filename, data, 0600)
} else { } else {
@@ -477,8 +472,16 @@ func issueSingleCertificate(def CertificateDefinition) error {
} }
// Add default dns SAN for server/server-only if none specified // Add default dns SAN for server/server-only if none specified
if (def.Type == "server" || def.Type == "server-only") && len(def.SAN) == 0 { if strings.Contains(def.Type, "server") && len(def.SAN) == 0 {
def.SAN = append(def.SAN, "dns:"+def.Subject) // Extract CN if subject is a DN, else use subject as is
cn := def.Subject
if isDNFormat(def.Subject) {
dn := parseDistinguishedName(def.Subject)
if dn.CommonName != "" {
cn = dn.CommonName
}
}
def.SAN = append(def.SAN, "dns:"+cn)
} }
priv, err := rsa.GenerateKey(rand.Reader, 4096) priv, err := rsa.GenerateKey(rand.Reader, 4096)
@@ -600,7 +603,7 @@ Certificate:
} }
// A prototype of certificate provisioning function // A prototype of certificate provisioning function
func ProvisionCertificates(filePath string, overwrite bool, dryRun bool, verbose bool) error { func ProvisionCertificates(filePath string) error {
err := LoadCA() err := LoadCA()
if err != nil { if err != nil {
@@ -677,7 +680,7 @@ func ProvisionCertificates(filePath string, overwrite bool, dryRun bool, verbose
return nil return nil
} }
func IssueCertificate(certDef CertificateDefinition, overwrite bool, dryRun bool, verbose bool) error { func IssueCertificate(certDef CertificateDefinition) error {
err := LoadCA() err := LoadCA()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) fmt.Fprintf(os.Stderr, "ERROR: %v\n", err)

View File

@@ -8,7 +8,6 @@ import (
) )
// Global flags available to all commands // Global flags available to all commands
var overwrite bool
var dryRun bool var dryRun bool
var verbose bool var verbose bool
@@ -46,7 +45,6 @@ func main() {
} }
// Define persistent flags (global for all commands) // Define persistent flags (global for all commands)
rootCmd.PersistentFlags().BoolVar(&overwrite, "overwrite", false, "Allow overwriting existing files")
rootCmd.PersistentFlags().BoolVar(&verbose, "verbose", false, "Print detailed information about each processed certificate") rootCmd.PersistentFlags().BoolVar(&verbose, "verbose", false, "Print detailed information about each processed certificate")
rootCmd.PersistentFlags().BoolVar(&dryRun, "dry-run", false, "Validate and show what would be created, but do not write files (batch mode)") rootCmd.PersistentFlags().BoolVar(&dryRun, "dry-run", false, "Validate and show what would be created, but do not write files (batch mode)")
rootCmd.PersistentFlags().StringVar(&caConfigPath, "config", "ca_config.hcl", "Path to CA configuration file") rootCmd.PersistentFlags().StringVar(&caConfigPath, "config", "ca_config.hcl", "Path to CA configuration file")
@@ -95,7 +93,7 @@ func main() {
Type: certType, Type: certType,
Validity: validity, Validity: validity,
SAN: san, SAN: san,
}, overwrite, dryRun, verbose) })
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) fmt.Fprintf(os.Stderr, "ERROR: %v\n", err)
@@ -120,7 +118,7 @@ func main() {
Short: "Provision certificates from a batch file (HCL)", Short: "Provision certificates from a batch file (HCL)",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
err := ProvisionCertificates(provisionFile, overwrite, false, verbose) err := ProvisionCertificates(provisionFile)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) fmt.Fprintf(os.Stderr, "ERROR: %v\n", err)