From e2039550e0ca7b3d561f3746c1484df90222daf3 Mon Sep 17 00:00:00 2001 From: Slawek Koszewski Date: Sun, 27 Jul 2025 20:45:05 +0200 Subject: [PATCH] CAConfig global variable and refactoring. --- ca.go | 553 +++++++++++++++++++----------------------------------- certdb.go | 34 ++-- main.go | 11 +- 3 files changed, 224 insertions(+), 374 deletions(-) diff --git a/ca.go b/ca.go index 890888f..9e1c8d0 100644 --- a/ca.go +++ b/ca.go @@ -25,7 +25,7 @@ type Paths struct { PrivateKeys string `hcl:"private_keys"` } -type CAConfig struct { +type _CAConfig struct { Label string `hcl:",label"` Name string `hcl:"name"` Country string `hcl:"country"` @@ -41,7 +41,7 @@ type CAConfig struct { } type Configuration struct { - Current CAConfig `hcl:"ca,block"` + Current _CAConfig `hcl:"ca,block"` } type CertificateDefinition struct { @@ -64,32 +64,42 @@ type Certificates struct { Certificates []CertificateDefinition `hcl:"certificate,block"` } -func LoadCA(path string) (*CAConfig, error) { +// Global CA configurationa and state variables +var CAState *_CAState +var CAConfig *_CAConfig + +// LoadCA loads the CA config and sets the global CAConfig variable +func LoadCA(path string) error { parser := hclparse.NewParser() file, diags := parser.ParseHCLFile(path) if diags.HasErrors() { - return nil, fmt.Errorf("failed to parse HCL: %s", diags.Error()) + return fmt.Errorf("failed to parse HCL: %s", diags.Error()) } var config Configuration diags = gohcl.DecodeBody(file.Body, nil, &config) if diags.HasErrors() { - return nil, fmt.Errorf("failed to decode HCL: %s", diags.Error()) + return fmt.Errorf("failed to decode HCL: %s", diags.Error()) } - if (CAConfig{}) == config.Current { - return nil, fmt.Errorf("no 'ca' block found in config file") + if (_CAConfig{}) == config.Current { + return fmt.Errorf("no 'ca' block found in config file") } if config.Current.Label == "" { - return nil, fmt.Errorf("the 'ca' block must have a label (e.g., ca \"mylabel\" {...})") + return fmt.Errorf("the 'ca' block must have a label (e.g., ca \"mylabel\" {...})") } if err := config.Current.Validate(); err != nil { - return nil, err + return err } + CAConfig = &config.Current // Derive caStatePath from caConfig label and config file path caDir := filepath.Dir(path) caLabel := config.Current.Label caStatePath := filepath.Join(caDir, caLabel+"_state.json") - GlobalCAState, _ = LoadCAState(caStatePath) - return &config.Current, nil + err := error(nil) + CAState, err = LoadCAState(caStatePath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to load CA state: %w", err) + } + return nil } // Parse certificates.hcl file with defaults support @@ -134,8 +144,9 @@ func parseValidity(validity string) (time.Duration, error) { } } -func GenerateCA(ca *CAConfig) ([]byte, []byte, error) { - keySize := ca.KeySize +func GenerateCA() ([]byte, []byte, error) { + // Use global CAConfig directly + keySize := CAConfig.KeySize if keySize == 0 { keySize = 4096 } @@ -148,7 +159,7 @@ func GenerateCA(ca *CAConfig) ([]byte, []byte, error) { if err != nil { return nil, nil, fmt.Errorf("failed to generate serial number: %v", err) } - validity, err := parseValidity(ca.Validity) + validity, err := parseValidity(CAConfig.Validity) if err != nil { return nil, nil, err } @@ -156,12 +167,12 @@ func GenerateCA(ca *CAConfig) ([]byte, []byte, error) { tmpl := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ - Country: []string{ca.Country}, - Organization: []string{ca.Organization}, - OrganizationalUnit: optionalSlice(ca.OrganizationalUnit), - Locality: optionalSlice(ca.Locality), - Province: optionalSlice(ca.Province), - CommonName: ca.Name, + Country: []string{CAConfig.Country}, + Organization: []string{CAConfig.Organization}, + OrganizationalUnit: optionalSlice(CAConfig.OrganizationalUnit), + Locality: optionalSlice(CAConfig.Locality), + Province: optionalSlice(CAConfig.Province), + CommonName: CAConfig.Name, }, NotBefore: now, NotAfter: now.Add(validity), @@ -170,10 +181,10 @@ func GenerateCA(ca *CAConfig) ([]byte, []byte, error) { IsCA: true, } // Add email if present - if ca.Email != "" { + if CAConfig.Email != "" { tmpl.Subject.ExtraNames = append(tmpl.Subject.ExtraNames, pkix.AttributeTypeAndValue{ Type: []int{1, 2, 840, 113549, 1, 9, 1}, // emailAddress OID - Value: ca.Email, + Value: CAConfig.Email, }) } certDER, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) @@ -211,7 +222,7 @@ func (p *Paths) Validate() error { return nil } -func (c *CAConfig) Validate() error { +func (c *_CAConfig) Validate() error { if c.Name == "" { return fmt.Errorf("CA 'name' is required") } @@ -236,49 +247,179 @@ func (c *CAConfig) Validate() error { } func InitCA(configPath string, overwrite bool) { - ca, err := LoadCA(configPath) - if err != nil { + if err := LoadCA(configPath); err != nil { fmt.Println("Error loading config:", err) return } - // Create certificates directory with 0755, private keys with 0700 - if ca.Paths.Certificates != "" { - if err := os.MkdirAll(ca.Paths.Certificates, 0755); err != nil { - fmt.Printf("Error creating certificates directory '%s': %v\n", ca.Paths.Certificates, err) + if CAConfig.Paths.Certificates != "" { + if err := os.MkdirAll(CAConfig.Paths.Certificates, 0755); err != nil { + fmt.Printf("Error creating certificates directory '%s': %v\n", CAConfig.Paths.Certificates, err) return } } - if ca.Paths.PrivateKeys != "" { - if err := os.MkdirAll(ca.Paths.PrivateKeys, 0700); err != nil { - fmt.Printf("Error creating private keys directory '%s': %v\n", ca.Paths.PrivateKeys, err) + if CAConfig.Paths.PrivateKeys != "" { + if err := os.MkdirAll(CAConfig.Paths.PrivateKeys, 0700); err != nil { + fmt.Printf("Error creating private keys directory '%s': %v\n", CAConfig.Paths.PrivateKeys, err) return } } - - certPEM, keyPEM, err := GenerateCA(ca) + certPEM, keyPEM, err := GenerateCA() if err != nil { fmt.Println("Error generating CA:", err) return } - if err := SavePEM(filepath.Join(ca.Paths.Certificates, "ca_cert.pem"), certPEM, false, overwrite); err != nil { + if err := SavePEM(filepath.Join(CAConfig.Paths.Certificates, "ca_cert.pem"), certPEM, false, overwrite); err != nil { fmt.Println("Error saving CA certificate:", err) return } - if err := SavePEM(filepath.Join(ca.Paths.PrivateKeys, "ca_key.pem"), keyPEM, true, overwrite); err != nil { + if err := SavePEM(filepath.Join(CAConfig.Paths.PrivateKeys, "ca_key.pem"), keyPEM, true, overwrite); err != nil { fmt.Println("Error saving CA key:", err) return } fmt.Println("CA certificate and key generated.") } -func IssueCertificate(configPath, subject, certType, validity string, san []string, name, fromFile string, overwrite, dryRun, verbose bool) { - // Load CA config to get label for state file - ca, err := LoadCA(configPath) - if err != nil { - fmt.Printf("Error loading CA config: %v\n", err) - return +// Helper: issue a single certificate and key, save to files, return error if any +func issueSingleCertificate(def CertificateDefinition, overwrite, verbose bool) error { + // Use global CAConfig directly + // Add default dns SAN for server/server-only if none specified + if (def.Type == "server" || def.Type == "server-only") && len(def.SAN) == 0 { + def.SAN = append(def.SAN, "dns:"+def.Subject) } + + caCertPath := filepath.Join(CAConfig.Paths.Certificates, "ca_cert.pem") + caKeyPath := filepath.Join(CAConfig.Paths.PrivateKeys, "ca_key.pem") + + caCertPEM, err := os.ReadFile(caCertPath) + if err != nil { + return fmt.Errorf("error reading CA certificate file: %v", err) + } + caKeyPEM, err := os.ReadFile(caKeyPath) + if err != nil { + return fmt.Errorf("error reading CA key file: %v", err) + } + + caCertBlock, _ := pem.Decode(caCertPEM) + if caCertBlock == nil { + return fmt.Errorf("failed to parse CA certificate PEM") + } + caCert, err := x509.ParseCertificate(caCertBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse CA certificate: %v", err) + } + caKeyBlock, _ := pem.Decode(caKeyPEM) + if caKeyBlock == nil { + return fmt.Errorf("failed to parse CA key PEM") + } + caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse CA private key: %v", err) + } + + priv, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return fmt.Errorf("failed to generate private key: %v", err) + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return fmt.Errorf("failed to generate serial number: %v", err) + } + + var validityDur time.Duration + if def.Validity != "" { + validityDur, err = parseValidity(def.Validity) + if err != nil { + return fmt.Errorf("invalid validity value: %v", err) + } + } else { + validityDur = 365 * 24 * time.Hour // default 1 year + } + + var subjectPKIX pkix.Name + if isDNFormat(def.Subject) { + subjectPKIX = parseDistinguishedName(def.Subject) + } else { + subjectPKIX = pkix.Name{CommonName: def.Subject} + } + + certTmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: subjectPKIX, + NotBefore: time.Now(), + NotAfter: time.Now().Add(validityDur), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + } + + for _, s := range def.SAN { + sLower := strings.ToLower(s) + var val string + if n, _ := fmt.Sscanf(sLower, "dns:%s", &val); n == 1 { + certTmpl.DNSNames = append(certTmpl.DNSNames, val) + } else if n, _ := fmt.Sscanf(sLower, "ip:%s", &val); n == 1 { + certTmpl.IPAddresses = append(certTmpl.IPAddresses, net.ParseIP(val)) + } else if n, _ := fmt.Sscanf(sLower, "email:%s", &val); n == 1 { + certTmpl.EmailAddresses = append(certTmpl.EmailAddresses, val) + } else { + return fmt.Errorf("invalid SAN format: %s", s) + } + } + + switch def.Type { + case "client": + certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} + case "server": + certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth} + case "server-only": + certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + case "code-signing": + certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning} + case "email": + certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageEmailProtection} + default: + return fmt.Errorf("unknown certificate type. Use one of: client, server, server-only, code-signing, email") + } + + certDER, err := x509.CreateCertificate(rand.Reader, &certTmpl, caCert, &priv.PublicKey, caKey) + if err != nil { + return fmt.Errorf("failed to create certificate: %v", err) + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + basename := def.Name + if basename == "" { + basename = def.Subject + } + certFile := filepath.Join(CAConfig.Paths.Certificates, basename+"."+def.Type+".crt.pem") + keyFile := filepath.Join(CAConfig.Paths.PrivateKeys, basename+"."+def.Type+".key.pem") + if err := SavePEM(certFile, certPEM, false, overwrite); err != nil { + return fmt.Errorf("error saving certificate: %v", err) + } + if err := SavePEM(keyFile, keyPEM, true, overwrite); err != nil { + return fmt.Errorf("error saving key: %v", err) + } + if verbose { + fmt.Printf(` +Certificate: + Name: %s + Subject: %s + Type: %s + Validity: %s + SAN: %v +`, + def.Name, + def.Subject, + def.Type, + def.Validity, + def.SAN, + ) + } + return nil +} + +func IssueCertificate(configPath, subject, certType, validity string, san []string, name, fromFile string, overwrite, dryRun, verbose bool) { if fromFile != "" { certDefs, defaults, err := LoadCertificatesFile(fromFile) if err != nil { @@ -287,7 +428,6 @@ func IssueCertificate(configPath, subject, certType, validity string, san []stri } successes := 0 errors := 0 - var basename, certFile, keyFile string // Declare variables before the loop for i, def := range certDefs { if defaults != nil { if def.Type == "" { @@ -301,347 +441,50 @@ func IssueCertificate(configPath, subject, certType, validity string, san []stri } } finalDef := renderCertificateDefTemplates(def, defaults) - fmt.Printf("[%d/%d] Issuing %s... ", i+1, len(certDefs), finalDef.Name) - if dryRun { fmt.Printf("(dry run)\n") - } - - if verbose { - fmt.Printf("\n Name: %s\n", finalDef.Name) - fmt.Printf(" Subject: %s\n", finalDef.Subject) - fmt.Printf(" Type: %s\n", finalDef.Type) - fmt.Printf(" Validity: %s\n", finalDef.Validity) - fmt.Printf(" SAN: %v\n\n", finalDef.SAN) - } - - basename = finalDef.Name - if basename == "" { - basename = finalDef.Subject - } - certFile = filepath.Join(ca.Paths.Certificates, basename+"."+finalDef.Type+".crt.pem") - keyFile = filepath.Join(ca.Paths.PrivateKeys, basename+"."+finalDef.Type+".key.pem") - - if dryRun { successes++ continue } - - // Inline certificate issuance logic for batch mode - // Add default dns SAN for server/server-only if none specified - if (finalDef.Type == "server" || finalDef.Type == "server-only") && len(finalDef.SAN) == 0 { - finalDef.SAN = append(finalDef.SAN, "dns:"+finalDef.Subject) - } - - caCertPath := filepath.Join(ca.Paths.Certificates, "ca_cert.pem") - caKeyPath := filepath.Join(ca.Paths.PrivateKeys, "ca_key.pem") - - caCertPEM, err := os.ReadFile(caCertPath) + err := issueSingleCertificate(finalDef, overwrite, verbose) if err != nil { - fmt.Println("Error reading CA certificate file:", err) + fmt.Printf("error: %v\n", err) errors++ - continue - } - caKeyPEM, err := os.ReadFile(caKeyPath) - if err != nil { - fmt.Println("Error reading CA key file:", err) - errors++ - continue - } - - caCertBlock, _ := pem.Decode(caCertPEM) - if caCertBlock == nil { - fmt.Println("Failed to parse CA certificate PEM") - errors++ - continue - } - caCert, err := x509.ParseCertificate(caCertBlock.Bytes) - if err != nil { - fmt.Println("Failed to parse CA certificate:", err) - errors++ - continue - } - caKeyBlock, _ := pem.Decode(caKeyPEM) - if caKeyBlock == nil { - fmt.Println("Failed to parse CA key PEM") - errors++ - continue - } - caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) - if err != nil { - fmt.Println("Failed to parse CA private key:", err) - errors++ - continue - } - - priv, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - fmt.Println("Failed to generate private key:", err) - errors++ - continue - } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - fmt.Println("Failed to generate serial number:", err) - errors++ - continue - } - - var validityDur time.Duration - if finalDef.Validity != "" { - validityDur, err = parseValidity(finalDef.Validity) - if err != nil { - fmt.Println("Invalid validity value:", err) - errors++ - continue - } } else { - validityDur = 365 * 24 * time.Hour // default 1 year - } - - var subjectPKIX pkix.Name - if isDNFormat(finalDef.Subject) { - subjectPKIX = parseDistinguishedName(finalDef.Subject) - } else { - subjectPKIX = pkix.Name{CommonName: finalDef.Subject} - } - - certTmpl := x509.Certificate{ - SerialNumber: serialNumber, - Subject: subjectPKIX, - NotBefore: time.Now(), - NotAfter: time.Now().Add(validityDur), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - } - - for _, s := range finalDef.SAN { - sLower := strings.ToLower(s) - var val string - if n, _ := fmt.Sscanf(sLower, "dns:%s", &val); n == 1 { - certTmpl.DNSNames = append(certTmpl.DNSNames, val) - } else if n, _ := fmt.Sscanf(sLower, "ip:%s", &val); n == 1 { - certTmpl.IPAddresses = append(certTmpl.IPAddresses, net.ParseIP(val)) - } else if n, _ := fmt.Sscanf(sLower, "email:%s", &val); n == 1 { - certTmpl.EmailAddresses = append(certTmpl.EmailAddresses, val) - } else { - fmt.Printf("Invalid SAN format: %s\n", s) - errors++ - continue + if !verbose { + fmt.Printf("done\n") } + successes++ } - - switch finalDef.Type { - case "client": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} - case "server": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth} - case "server-only": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} - case "code-signing": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning} - case "email": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageEmailProtection} - default: - fmt.Println("Unknown certificate type. Use one of: client, server, server-only, code-signing, email.") - errors++ - continue - } - - certDER, err := x509.CreateCertificate(rand.Reader, &certTmpl, caCert, &priv.PublicKey, caKey) - if err != nil { - fmt.Println("Failed to create certificate:", err) - errors++ - continue - } - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) - if err := SavePEM(certFile, certPEM, false, overwrite); err != nil { - fmt.Println("Error saving certificate:", err) - errors++ - continue - } - if err := SavePEM(keyFile, keyPEM, true, overwrite); err != nil { - fmt.Println("Error saving key:", err) - errors++ - continue - } - if !verbose { - fmt.Printf("done\n") - } - successes++ } fmt.Printf("Batch complete: %d succeeded, %d failed.\n", successes, errors) // Save CA state after batch issuance caDir := filepath.Dir(configPath) - caLabel := ca.Label + caLabel := CAConfig.Label caStatePath := filepath.Join(caDir, caLabel+"_state.json") - if err := SaveCAState(caStatePath, GlobalCAState); err != nil { + if err := SaveCAState(caStatePath, CAState); err != nil { fmt.Printf("Error saving CA state: %v\n", err) } return } - // Simple mode - subjectName := subject - if subjectName == "" { - subjectName = name - } + // Single mode finalDef := renderCertificateDefTemplates(CertificateDefinition{Name: name, Subject: subject, Type: certType, Validity: validity, SAN: san}, nil) - if verbose { - fmt.Printf("\nCertificate:\n") - fmt.Printf(" Name: %s\n", finalDef.Name) - fmt.Printf(" Subject: %s\n", finalDef.Subject) - fmt.Printf(" Type: %s\n", finalDef.Type) - fmt.Printf(" Validity: %s\n", finalDef.Validity) - fmt.Printf(" SAN: %v\n", finalDef.SAN) + if dryRun { + fmt.Printf("Would issue %s certificate for '%s' (dry run)\n", finalDef.Type, finalDef.Subject) + return } - // Inline the logic from internalIssueCertificate here - // Add default dns SAN for server/server-only if none specified - if (finalDef.Type == "server" || finalDef.Type == "server-only") && len(finalDef.SAN) == 0 { - finalDef.SAN = append(finalDef.SAN, "dns:"+finalDef.Subject) - } - - ca, err = LoadCA(configPath) + err := issueSingleCertificate(finalDef, overwrite, verbose) if err != nil { - fmt.Println("Error loading config:", err) - return - } - - caCertPath := filepath.Join(ca.Paths.Certificates, "ca_cert.pem") - caKeyPath := filepath.Join(ca.Paths.PrivateKeys, "ca_key.pem") - - caCertPEM, err := os.ReadFile(caCertPath) - if err != nil { - fmt.Println("Error reading CA certificate file:", err) - return - } - caKeyPEM, err := os.ReadFile(caKeyPath) - if err != nil { - fmt.Println("Error reading CA key file:", err) - return - } - - caCertBlock, _ := pem.Decode(caCertPEM) - if caCertBlock == nil { - fmt.Println("Failed to parse CA certificate PEM") - return - } - caCert, err := x509.ParseCertificate(caCertBlock.Bytes) - if err != nil { - fmt.Println("Failed to parse CA certificate:", err) - return - } - caKeyBlock, _ := pem.Decode(caKeyPEM) - if caKeyBlock == nil { - fmt.Println("Failed to parse CA key PEM") - return - } - caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) - if err != nil { - fmt.Println("Failed to parse CA private key:", err) - return - } - - priv, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - fmt.Println("Failed to generate private key:", err) - return - } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - fmt.Println("Failed to generate serial number:", err) - return - } - - var validityDur time.Duration - if finalDef.Validity != "" { - validityDur, err = parseValidity(finalDef.Validity) - if err != nil { - fmt.Println("Invalid validity value:", err) - return - } - } else { - validityDur = 365 * 24 * time.Hour // default 1 year - } - - // Parse subject as DN if it looks like a DN, otherwise use as CommonName only - var subjectPKIX pkix.Name - if isDNFormat(finalDef.Subject) { - subjectPKIX = parseDistinguishedName(finalDef.Subject) - } else { - subjectPKIX = pkix.Name{CommonName: finalDef.Subject} - } - - certTmpl := x509.Certificate{ - SerialNumber: serialNumber, - Subject: subjectPKIX, - NotBefore: time.Now(), - NotAfter: time.Now().Add(validityDur), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - } - - // Handle SANs - for _, s := range finalDef.SAN { - sLower := strings.ToLower(s) - var val string - if n, _ := fmt.Sscanf(sLower, "dns:%s", &val); n == 1 { - certTmpl.DNSNames = append(certTmpl.DNSNames, val) - } else if n, _ := fmt.Sscanf(sLower, "ip:%s", &val); n == 1 { - certTmpl.IPAddresses = append(certTmpl.IPAddresses, net.ParseIP(val)) - } else if n, _ := fmt.Sscanf(sLower, "email:%s", &val); n == 1 { - certTmpl.EmailAddresses = append(certTmpl.EmailAddresses, val) - } else { - fmt.Printf("Invalid SAN format: %s\n", s) - return - } - } - - switch finalDef.Type { - case "client": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} - case "server": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth} - case "server-only": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} - case "code-signing": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning} - case "email": - certTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageEmailProtection} - default: - fmt.Println("Unknown certificate type. Use one of: client, server, server-only, code-signing, email.") - return - } - - certDER, err := x509.CreateCertificate(rand.Reader, &certTmpl, caCert, &priv.PublicKey, caKey) - if err != nil { - fmt.Println("Failed to create certificate:", err) - return - } - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) - - basename := finalDef.Name - if basename == "" { - basename = finalDef.Subject - } - certFile := filepath.Join(ca.Paths.Certificates, basename+"."+finalDef.Type+".crt.pem") - keyFile := filepath.Join(ca.Paths.PrivateKeys, basename+"."+finalDef.Type+".key.pem") - if err := SavePEM(certFile, certPEM, false, overwrite); err != nil { - fmt.Println("Error saving certificate:", err) - return - } - if err := SavePEM(keyFile, keyPEM, true, overwrite); err != nil { - fmt.Println("Error saving key:", err) + fmt.Printf("Error: %v\n", err) return } fmt.Printf("%s certificate and key for '%s' generated.\n", finalDef.Type, finalDef.Subject) // Save CA state after single issuance caDir := filepath.Dir(configPath) - caLabel := ca.Label + caLabel := CAConfig.Label caStatePath := filepath.Join(caDir, caLabel+"_state.json") - if err := SaveCAState(caStatePath, GlobalCAState); err != nil { + if err := SaveCAState(caStatePath, CAState); err != nil { fmt.Printf("Error saving CA state: %v\n", err) } } diff --git a/certdb.go b/certdb.go index 1566ee6..e4b2e01 100644 --- a/certdb.go +++ b/certdb.go @@ -8,9 +8,9 @@ import ( "time" ) -// CAState represents the persisted CA state in JSON +// _CAState represents the persisted CA state in JSON // (matches the structure of example_ca.json) -type CAState struct { +type _CAState struct { CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` Serial int `json:"serial,omitempty"` @@ -26,13 +26,13 @@ type CertificateRecord struct { } // LoadCAState loads the CA state from a JSON file -func LoadCAState(filename string) (*CAState, error) { +func LoadCAState(filename string) (*_CAState, error) { f, err := os.Open(filename) if err != nil { return nil, err } defer f.Close() - var state CAState + var state _CAState if err := json.NewDecoder(f).Decode(&state); err != nil { return nil, err } @@ -40,7 +40,7 @@ func LoadCAState(filename string) (*CAState, error) { } // SaveCAState saves the CA state to a JSON file -func SaveCAState(filename string, state *CAState) error { +func SaveCAState(filename string, state *_CAState) error { state.UpdatedAt = time.Now().UTC().Format(time.RFC3339) f, err := os.Create(filename) if err != nil { @@ -55,14 +55,14 @@ func SaveCAState(filename string, state *CAState) error { // UpdateCAStateAfterIssue updates the CA state JSON after issuing a certificate func UpdateCAStateAfterIssue(jsonFile, serialType, basename string, serialNumber any, validity time.Duration) error { var err error - if GlobalCAState == nil { - GlobalCAState, err = LoadCAState(jsonFile) + if CAState == nil { + CAState, err = LoadCAState(jsonFile) if err != nil { - GlobalCAState = nil + CAState = nil } } - if GlobalCAState == nil { - fmt.Fprintf(os.Stderr, "FATAL: GlobalCAState is nil in UpdateCAStateAfterIssue. This indicates a programming error.\n") + if CAState == nil { + fmt.Fprintf(os.Stderr, "FATAL: CAState is nil in UpdateCAStateAfterIssue. This indicates a programming error.\n") os.Exit(1) } issued := time.Now().UTC().Format(time.RFC3339) @@ -70,8 +70,8 @@ func UpdateCAStateAfterIssue(jsonFile, serialType, basename string, serialNumber serialStr := "" switch serialType { case "sequential": - serialStr = fmt.Sprintf("%d", GlobalCAState.Serial) - GlobalCAState.Serial++ + serialStr = fmt.Sprintf("%d", CAState.Serial) + CAState.Serial++ case "random": serialStr = fmt.Sprintf("%x", serialNumber) default: @@ -81,10 +81,10 @@ func UpdateCAStateAfterIssue(jsonFile, serialType, basename string, serialNumber return nil } -// AddCertificate appends a new CertificateRecord to the GlobalCAState +// AddCertificate appends a new CertificateRecord to the CAState func AddCertificate(name, issued, expires, serial string, valid bool) { - if GlobalCAState == nil { - fmt.Fprintf(os.Stderr, "FATAL: GlobalCAState is nil in AddCertificate. This indicates a programming error.\n") + if CAState == nil { + fmt.Fprintf(os.Stderr, "FATAL: CAState is nil in AddCertificate. This indicates a programming error.\n") os.Exit(1) } rec := CertificateRecord{ @@ -94,5 +94,7 @@ func AddCertificate(name, issued, expires, serial string, valid bool) { Serial: serial, Valid: valid, } - GlobalCAState.Certificates = append(GlobalCAState.Certificates, rec) + CAState.Certificates = append(CAState.Certificates, rec) } + +// No CAConfig references to update in this file diff --git a/main.go b/main.go index 6132a4f..f990c59 100644 --- a/main.go +++ b/main.go @@ -9,9 +9,6 @@ import ( var Version = "dev" -// Global CA state variable -var GlobalCAState *CAState - func main() { var configPath string var overwrite bool @@ -37,6 +34,10 @@ func main() { Use: "initca", Short: "Generate a new CA certificate and key", Run: func(cmd *cobra.Command, args []string) { + if err := LoadCA(configPath); err != nil { + fmt.Printf("Error loading CA config: %v\n", err) + os.Exit(1) + } InitCA(configPath, overwrite) }, } @@ -48,6 +49,10 @@ func main() { Use: "issue", Short: "Issue a new certificate (client, server, server-only, code-signing, email)", Run: func(cmd *cobra.Command, args []string) { + if err := LoadCA(configPath); err != nil { + fmt.Printf("Error loading CA config: %v\n", err) + os.Exit(1) + } IssueCertificate(configPath, subject, certType, validity, san, name, fromFile, overwrite, dryRun, verbose) }, }