package main import ( "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "fmt" "math/big" "os" "path/filepath" "time" gohcl "github.com/hashicorp/hcl/v2/gohcl" hclparse "github.com/hashicorp/hcl/v2/hclparse" ) type Paths struct { Certificates string `hcl:"certificates"` PrivateKeys string `hcl:"private_keys"` } type CA struct { Label string `hcl:",label"` Name string `hcl:"name"` Country string `hcl:"country"` Organization string `hcl:"organization"` SerialType string `hcl:"serial_type"` KeySize int `hcl:"key_size,optional"` Validity string `hcl:"validity,optional"` Paths Paths `hcl:"paths,block"` } type Configuration struct { CA CA `hcl:"ca,block"` } func LoadCA(path string) (*CA, error) { parser := hclparse.NewParser() file, diags := parser.ParseHCLFile(path) if diags.HasErrors() { return nil, fmt.Errorf("failed to parse HCL: %s", diags.Error()) } var config Configuration diags = gohcl.DecodeBody(file.Body, nil, &config) if diags.HasErrors() { return nil, fmt.Errorf("failed to decode HCL: %s", diags.Error()) } if (CA{}) == config.CA { return nil, fmt.Errorf("no 'ca' block found in config file") } if err := config.CA.Validate(); err != nil { return nil, err } return &config.CA, nil } func parseValidity(validity string) (time.Duration, error) { if validity == "" { return time.Hour * 24 * 365 * 5, nil // default 5 years } var n int var unit rune _, err := fmt.Sscanf(validity, "%d%c", &n, &unit) if err != nil { // If no unit, assume years _, err2 := fmt.Sscanf(validity, "%d", &n) if err2 != nil { return 0, fmt.Errorf("invalid validity format: %s", validity) } unit = 'y' } switch unit { case 'y': return time.Hour * 24 * 365 * time.Duration(n), nil case 'm': return time.Hour * 24 * 30 * time.Duration(n), nil case 'd': return time.Hour * 24 * time.Duration(n), nil default: return 0, fmt.Errorf("invalid validity unit: %c", unit) } } func GenerateCA(ca *CA) ([]byte, []byte, error) { keySize := ca.KeySize if keySize == 0 { keySize = 4096 } priv, err := rsa.GenerateKey(rand.Reader, keySize) if err != nil { return nil, nil, err } serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return nil, nil, fmt.Errorf("failed to generate serial number: %v", err) } validity, err := parseValidity(ca.Validity) if err != nil { return nil, nil, err } now := time.Now() tmpl := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ Country: []string{ca.Country}, Organization: []string{ca.Organization}, CommonName: ca.Name, }, NotBefore: now, NotAfter: now.Add(validity), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, BasicConstraintsValid: true, IsCA: true, } certDER, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) if err != nil { return nil, nil, err } certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) return certPEM, keyPEM, nil } func SavePEM(filename string, data []byte, secure bool, overwrite bool) error { if !overwrite { if _, err := os.Stat(filename); err == nil { return fmt.Errorf("file %s already exists (overwrite not allowed)", filename) } else if !os.IsNotExist(err) { return fmt.Errorf("could not check file %s: %v", filename, err) } } if secure { return os.WriteFile(filename, data, 0600) } else { return os.WriteFile(filename, data, 0644) } } func (p *Paths) Validate() error { if p.Certificates == "" { return fmt.Errorf("paths.certificates is required") } if p.PrivateKeys == "" { return fmt.Errorf("paths.private_keys is required") } return nil } func (c *CA) Validate() error { if c.Name == "" { return fmt.Errorf("CA 'name' is required") } if c.Country == "" { return fmt.Errorf("CA 'country' is required") } if c.Organization == "" { return fmt.Errorf("CA 'organization' is required") } if c.SerialType == "" { c.SerialType = "random" } if c.SerialType != "random" && c.SerialType != "sequential" { return fmt.Errorf("CA 'serial_type' must be 'random' or 'sequential'") } if err := c.Paths.Validate(); err != nil { return err } return nil } func InitCA(configPath string, overwrite bool) { ca, err := LoadCA(configPath) if err != nil { fmt.Println("Error loading config:", err) return } // Create certificates directory with 0755, private keys with 0700 if ca.Paths.Certificates != "" { if err := os.MkdirAll(ca.Paths.Certificates, 0755); err != nil { fmt.Printf("Error creating certificates directory '%s': %v\n", ca.Paths.Certificates, err) return } } if ca.Paths.PrivateKeys != "" { if err := os.MkdirAll(ca.Paths.PrivateKeys, 0700); err != nil { fmt.Printf("Error creating private keys directory '%s': %v\n", ca.Paths.PrivateKeys, err) return } } certPEM, keyPEM, err := GenerateCA(ca) if err != nil { fmt.Println("Error generating CA:", err) return } if err := SavePEM(filepath.Join(ca.Paths.Certificates, "ca_cert.pem"), certPEM, false, overwrite); err != nil { fmt.Println("Error saving CA certificate:", err) return } if err := SavePEM(filepath.Join(ca.Paths.PrivateKeys, "ca_key.pem"), keyPEM, true, overwrite); err != nil { fmt.Println("Error saving CA key:", err) return } fmt.Println("CA certificate and key generated.") } func IssueCertificate(configPath, subject string, overwrite bool) { ca, err := LoadCA(configPath) if err != nil { fmt.Println("Error loading config:", err) return } caCertPath := filepath.Join(ca.Paths.Certificates, "ca_cert.pem") caKeyPath := filepath.Join(ca.Paths.PrivateKeys, "ca_key.pem") caCertPEM, err := os.ReadFile(caCertPath) if err != nil { fmt.Println("Error reading CA certificate file:", err) return } caKeyPEM, err := os.ReadFile(caKeyPath) if err != nil { fmt.Println("Error reading CA key file:", err) return } caCertBlock, _ := pem.Decode(caCertPEM) if caCertBlock == nil { fmt.Println("Failed to parse CA certificate PEM") return } caCert, err := x509.ParseCertificate(caCertBlock.Bytes) if err != nil { fmt.Println("Failed to parse CA certificate:", err) return } caKeyBlock, _ := pem.Decode(caKeyPEM) if caKeyBlock == nil { fmt.Println("Failed to parse CA key PEM") return } caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes) if err != nil { fmt.Println("Failed to parse CA private key:", err) return } priv, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { fmt.Println("Failed to generate private key:", err) return } serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { fmt.Println("Failed to generate serial number:", err) return } certTmpl := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ CommonName: subject, }, NotBefore: time.Now(), NotAfter: time.Now().AddDate(1, 0, 0), // 1 year KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, } certDER, err := x509.CreateCertificate(rand.Reader, &certTmpl, caCert, &priv.PublicKey, caKey) if err != nil { fmt.Println("Failed to create certificate:", err) return } certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) certFile := filepath.Join(ca.Paths.Certificates, subject+".crt.pem") keyFile := filepath.Join(ca.Paths.PrivateKeys, subject+".key.pem") if err := SavePEM(certFile, certPEM, false, overwrite); err != nil { fmt.Println("Error saving certificate:", err) return } if err := SavePEM(keyFile, keyPEM, true, overwrite); err != nil { fmt.Println("Error saving key:", err) return } fmt.Printf("Certificate and key for '%s' generated.\n", subject) }