Files

275 lines
8.0 KiB
Go

// Copyright (c) 2026 Sławomir Koszewski. All rights reserved.
// Use of this source code is governed by the MIT License
// that can be found in the LICENSE file.
package miab
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestNewClient_miabStyle verifies that NewClient, called with no explicit
// arguments, takes host, username and password from the MIAB_* variables.
func TestNewClient_miabStyle(t *testing.T) {
	for _, kv := range [][2]string{
		{"MIAB_HOST", "box.example.com"},
		{"MIAB_USERNAME", "user@example.com"},
		{"MIAB_PASSWORD", "secret"},
	} {
		t.Setenv(kv[0], kv[1])
	}

	client, err := NewClient("", "", "")
	if err != nil {
		t.Fatal(err)
	}

	got := [3]string{client.host, client.username, client.password}
	want := [3]string{"box.example.com", "user@example.com", "secret"}
	if got != want {
		t.Errorf("unexpected client fields: %+v", client)
	}
}
// TestNewClient_mailinaboxStyle verifies that NewClient, called with no
// explicit arguments, falls back to the MAILINABOX_* variables. The host is
// expected to be extracted from MAILINABOX_BASE_URL without its scheme.
//
// MIAB_* variables take precedence over MAILINABOX_* ones, so they are blanked
// first (t.Setenv restores them afterwards). Otherwise values exported in the
// caller's shell would override the ones set here and fail the test.
// NOTE(review): assumes NewClient treats an empty variable as unset.
func TestNewClient_mailinaboxStyle(t *testing.T) {
	for _, key := range []string{"MIAB_HOST", "MIAB_USERNAME", "MIAB_PASSWORD"} {
		t.Setenv(key, "")
	}
	t.Setenv("MAILINABOX_BASE_URL", "https://box.example.com")
	t.Setenv("MAILINABOX_EMAIL", "user@example.com")
	t.Setenv("MAILINABOX_PASSWORD", "secret")

	c, err := NewClient("", "", "")
	if err != nil {
		t.Fatal(err)
	}
	if c.host != "box.example.com" || c.username != "user@example.com" || c.password != "secret" {
		t.Errorf("unexpected client fields: %+v", c)
	}
}
// TestNewClient_miabTakesPrecedence sets both families of environment
// variables and checks that the MIAB_* values win over MAILINABOX_* ones.
func TestNewClient_miabTakesPrecedence(t *testing.T) {
	env := [][2]string{
		// Preferred family.
		{"MIAB_HOST", "miab.example.com"},
		{"MIAB_USERNAME", "miab@example.com"},
		{"MIAB_PASSWORD", "miabpass"},
		// Fallback family; must be ignored here.
		{"MAILINABOX_BASE_URL", "https://other.example.com"},
		{"MAILINABOX_EMAIL", "other@example.com"},
		{"MAILINABOX_PASSWORD", "otherpass"},
	}
	for _, kv := range env {
		t.Setenv(kv[0], kv[1])
	}

	client, err := NewClient("", "", "")
	if err != nil {
		t.Fatal(err)
	}

	got := [3]string{client.host, client.username, client.password}
	if got != [3]string{"miab.example.com", "miab@example.com", "miabpass"} {
		t.Errorf("MIAB_* vars should take precedence, got: %+v", client)
	}
}
// TestNewClient_missingVars checks that NewClient returns an error when no
// credentials come from either the arguments or the environment.
//
// Every variable NewClient reads in the other tests of this file is blanked
// first (t.Setenv restores it afterwards). Otherwise credentials exported in
// the caller's shell would make NewClient succeed and this test fail.
// NOTE(review): assumes NewClient treats an empty variable as unset.
func TestNewClient_missingVars(t *testing.T) {
	for _, key := range []string{
		"MIAB_HOST", "MIAB_USERNAME", "MIAB_PASSWORD",
		"MAILINABOX_BASE_URL", "MAILINABOX_EMAIL", "MAILINABOX_PASSWORD",
	} {
		t.Setenv(key, "")
	}

	_, err := NewClient("", "", "")
	if err == nil {
		t.Fatal("expected error when no env vars set")
	}
}
// TestNewClient_explicitParamsTakePrecedence checks that non-empty arguments
// passed to NewClient override whatever the MIAB_* variables contain.
func TestNewClient_explicitParamsTakePrecedence(t *testing.T) {
	t.Setenv("MIAB_HOST", "env.example.com")
	t.Setenv("MIAB_USERNAME", "env@example.com")
	t.Setenv("MIAB_PASSWORD", "envpass")

	want := [3]string{"explicit.example.com", "explicit@example.com", "explicitpass"}
	client, err := NewClient(want[0], want[1], want[2])
	if err != nil {
		t.Fatal(err)
	}

	if got := [3]string{client.host, client.username, client.password}; got != want {
		t.Errorf("explicit params should take precedence, got: %+v", client)
	}
}
// TestBuildURL checks the URL layout produced by buildURL. A/a records map to
// the bare /admin/dns/custom/<name> endpoint. Any other type is upper-cased and
// appended as a path segment.
func TestBuildURL(t *testing.T) {
	// Fail clearly if construction breaks, rather than panicking on a nil
	// client below.
	c, err := NewClient("box.example.com", "user", "pass")
	if err != nil {
		t.Fatal(err)
	}

	tests := []struct {
		name, recordType, want string
	}{
		{"example.com", "A", "https://box.example.com/admin/dns/custom/example.com"},
		{"example.com", "a", "https://box.example.com/admin/dns/custom/example.com"},
		{"example.com", "TXT", "https://box.example.com/admin/dns/custom/example.com/TXT"},
		{"example.com", "txt", "https://box.example.com/admin/dns/custom/example.com/TXT"},
		{"example.com", "MX", "https://box.example.com/admin/dns/custom/example.com/MX"},
	}
	for _, tt := range tests {
		got := c.buildURL(tt.name, tt.recordType)
		if got != tt.want {
			t.Errorf("buildURL(%q, %q) = %q; want %q", tt.name, tt.recordType, got, tt.want)
		}
	}
}
// TestSetRecord checks that SetRecord sends an authenticated PUT to
// /admin/dns/custom/<name>/<TYPE> with the record value as the raw body.
func TestSetRecord(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on the server's goroutine, so only t.Error* is
		// safe here (t.Fatal must be called from the test goroutine).
		if r.Method != http.MethodPut {
			t.Errorf("expected PUT, got %s", r.Method)
		}
		if r.URL.Path != "/admin/dns/custom/test.example.com/TXT" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		assertBasicAuth(t, r, "user@example.com", "secret")
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("reading request body: %v", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if string(body) != "hello" {
			t.Errorf("unexpected body: %q", body)
		}
		fmt.Fprintln(w, "OK")
	}))
	defer srv.Close()

	c := testClient(srv)
	if err := c.SetRecord("test.example.com", "TXT", "hello"); err != nil {
		t.Fatal(err)
	}
}
// TestDeleteRecord checks that DeleteRecord sends an authenticated DELETE to
// /admin/dns/custom/<name>/<TYPE>, carrying the value to remove as the body.
func TestDeleteRecord(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on the server's goroutine, so only t.Error* is
		// safe here (t.Fatal must be called from the test goroutine).
		if r.Method != http.MethodDelete {
			t.Errorf("expected DELETE, got %s", r.Method)
		}
		if r.URL.Path != "/admin/dns/custom/test.example.com/TXT" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		assertBasicAuth(t, r, "user@example.com", "secret")
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("reading request body: %v", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if string(body) != "hello" {
			t.Errorf("unexpected body: %q", body)
		}
		fmt.Fprintln(w, "OK")
	}))
	defer srv.Close()

	c := testClient(srv)
	if err := c.DeleteRecord("test.example.com", "TXT", "hello"); err != nil {
		t.Fatal(err)
	}
}
// TestAddRecord checks that adding an A record sends an authenticated POST to
// the type-less /admin/dns/custom/<name> endpoint.
func TestAddRecord(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if method := r.Method; method != http.MethodPost {
			t.Errorf("expected POST, got %s", method)
		}
		if path := r.URL.Path; path != "/admin/dns/custom/test.example.com" {
			t.Errorf("unexpected path: %s", path)
		}
		assertBasicAuth(t, r, "user@example.com", "secret")
		fmt.Fprintln(w, "OK")
	}
	ts := httptest.NewServer(http.HandlerFunc(handler))
	defer ts.Close()

	client := testClient(ts)
	err := client.AddRecord("test.example.com", "A", "1.2.3.4")
	if err != nil {
		t.Fatal(err)
	}
}
// TestListRecords_all checks that an empty type filter returns every record
// served by the API.
func TestListRecords_all(t *testing.T) {
	ts := jsonServer(t, []map[string]string{
		{"qname": "a.example.com", "rtype": "A", "value": "1.2.3.4"},
		{"qname": "b.example.com", "rtype": "TXT", "value": "v=spf1"},
	})
	defer ts.Close()

	got, err := testClient(ts).ListRecords("")
	if err != nil {
		t.Fatal(err)
	}
	if n := len(got); n != 2 {
		t.Fatalf("expected 2 records, got %d", n)
	}
}
// TestListRecords_filtered checks that a "TXT" filter keeps only the TXT
// record from a mixed response.
func TestListRecords_filtered(t *testing.T) {
	ts := jsonServer(t, []map[string]string{
		{"qname": "a.example.com", "rtype": "A", "value": "1.2.3.4"},
		{"qname": "b.example.com", "rtype": "TXT", "value": "v=spf1"},
	})
	defer ts.Close()

	txt, err := testClient(ts).ListRecords("TXT")
	if err != nil {
		t.Fatal(err)
	}
	if n := len(txt); n != 1 {
		t.Fatalf("expected 1 TXT record, got %d", n)
	}
	if name := txt[0].Name; name != "b.example.com" {
		t.Errorf("unexpected record name: %s", name)
	}
}
// TestErrorResponse checks that a non-2xx reply makes SetRecord fail with an
// error that mentions the HTTP status code.
func TestErrorResponse(t *testing.T) {
	unauthorized := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
	})
	ts := httptest.NewServer(unauthorized)
	defer ts.Close()

	err := testClient(ts).SetRecord("test.example.com", "TXT", "val")
	switch {
	case err == nil:
		t.Fatal("expected error, got nil")
	case !strings.Contains(err.Error(), "401"):
		t.Errorf("expected 401 in error, got: %v", err)
	}
}
// testClient returns a Client for user@example.com/secret that talks to srv.
//
// The Client is created with srv's host:port so buildURL yields the correct
// host and path. However, buildURL always emits https:// URLs and srv speaks
// plain HTTP, so a schemeRewriter redirects each request to srv.URL.
//
// The rewriter wraps srv's own transport inside a fresh http.Client. This
// leaves the http.Client returned by srv.Client() unmodified, and connections
// are pooled on srv's transport rather than on the process-wide
// http.DefaultTransport.
//
// It panics if NewClient fails, since it has no *testing.T to report through.
func testClient(srv *httptest.Server) *Client {
	c, err := NewClient(strings.TrimPrefix(srv.URL, "http://"), "user@example.com", "secret")
	if err != nil {
		panic(err)
	}
	c.httpClient = &http.Client{
		Transport: &schemeRewriter{
			underlying: srv.Client().Transport,
			targetURL:  srv.URL,
		},
	}
	return c
}
// schemeRewriter is an http.RoundTripper that sends every request to the
// scheme and host of targetURL, keeping path and query unchanged. It lets
// testClient drive the real buildURL logic against an httptest server.
type schemeRewriter struct {
	underlying http.RoundTripper // transport that actually performs the request
	targetURL  string            // e.g. "http://127.0.0.1:PORT" or "https://127.0.0.1:PORT"
}

// RoundTrip forwards a clone of req whose URL scheme and host come from
// s.targetURL. req itself is left untouched, as the RoundTripper contract
// requires. A targetURL with no "scheme://" prefix is treated as a bare
// host:port reached over plain HTTP.
func (s *schemeRewriter) RoundTrip(req *http.Request) (*http.Response, error) {
	// Take the scheme from targetURL instead of hard-coding "http", so
	// TLS test servers (https://...) are also rewritten correctly.
	scheme, host, ok := strings.Cut(s.targetURL, "://")
	if !ok {
		scheme, host = "http", s.targetURL
	}
	clone := req.Clone(req.Context())
	clone.URL.Scheme = scheme
	clone.URL.Host = host
	return s.underlying.RoundTrip(clone)
}
// jsonServer starts an httptest server that checks every request for the
// test credentials and replies with payload encoded as JSON. The caller is
// responsible for closing the returned server.
func jsonServer(t *testing.T, payload any) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assertBasicAuth(t, r, "user@example.com", "secret")
		w.Header().Set("Content-Type", "application/json")
		// Report encoding failures (e.g. an unencodable payload) instead of
		// letting the client see a truncated body with no explanation.
		// t.Errorf, not t.Fatal, because this runs on the server's goroutine.
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			t.Errorf("encoding JSON response: %v", err)
		}
	}))
}
// assertBasicAuth reports a test error unless r carries an HTTP Basic
// Authorization header with exactly the given username and password.
// It only uses t.Error/t.Errorf, so it is safe to call from an
// httptest handler goroutine.
func assertBasicAuth(t *testing.T, r *http.Request, username, password string) {
	t.Helper()
	u, p, ok := r.BasicAuth()
	if !ok {
		t.Error("expected Basic Auth header")
		// Without a header, u and p are empty; a second "mismatch" error
		// would only repeat the same failure.
		return
	}
	if u != username || p != password {
		t.Errorf("auth mismatch: got %q/%q, want %q/%q", u, p, username, password)
	}
}