amfora/client/tofu.go

138 lines
4.2 KiB
Go

package client
import (
"crypto/sha256"
"crypto/x509"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/makeworld-the-better-one/amfora/config"
)
// TOFU implementation.
// Stores cert hash and expiry for now, like Bombadillo.
// There is ongoing TOFU discussion on the mailing list about better
// ways to do this, and I will update this file once those are decided on.
// Update: See #7 for some small improvements made.

// ErrTofu reports that a server's certificate does not match the entry
// recorded for it in the TOFU database.
// NOTE(review): nothing in this file returns it; presumably callers use it
// when handleTofu reports false - confirm against the request code.
var ErrTofu = errors.New("server cert does not match TOFU database")

// tofuStore is the store (config.TofuStore) that holds the per-host
// fingerprint and expiry entries read and written by the funcs in this file.
// Access it only while holding tofuStoreMu.
var tofuStore = config.TofuStore
// tofuStoreMu protects tofuStore, since viper is not thread-safe.
// See this issue for details: https://github.com/spf13/viper/issues/268
// This is needed because Gemini requests may happen concurrently, and
// each of them may call the funcs in this file.
// Readers take the read lock; anything that calls Set or WriteConfig on
// tofuStore takes the write lock.
var tofuStoreMu = sync.RWMutex{}
// idKey returns the config/viper key needed to retrieve
// a cert's ID / fingerprint.
//
// A single trailing dot on the domain (the fully-qualified form) is dropped
// and the remaining dots are replaced with "/". This is the same
// normalisation expiryKey applies, so "example.com." and "example.com"
// resolve to one entry instead of keeping separate fingerprints while
// overwriting a shared expiry.
//
// An empty port or "1965" means the default port and adds no suffix; any
// other port is appended as ":<port>".
func idKey(domain string, port string) string {
	key := strings.ReplaceAll(strings.TrimSuffix(domain, "."), ".", "/")
	if port == "" || port == "1965" {
		return key
	}
	return key + ":" + port
}
// expiryKey returns the config/viper key needed to retrieve
// a cert's expiry time.
//
// One trailing dot is trimmed from the domain, the remaining dots become
// "/", and "/expiry" is appended. A port other than "" or "1965" is added
// after that as ":<port>".
func expiryKey(domain string, port string) string {
	host := strings.TrimSuffix(domain, ".")
	key := strings.ReplaceAll(host, ".", "/") + "/expiry"
	if port != "" && port != "1965" {
		key += ":" + port
	}
	return key
}
// loadTofuEntry reads the stored fingerprint and expiry for a host while
// holding the read lock on tofuStoreMu.
//
// It returns a "not found" error when the stored fingerprint does not have
// the length of a hex-encoded SHA-256 sum (unset or invalid), or when the
// stored expiry is the zero time. In the second case the fingerprint is
// still returned next to the error.
func loadTofuEntry(domain string, port string) (string, time.Time, error) {
	tofuStoreMu.RLock()
	defer tofuStoreMu.RUnlock()

	fingerprint := tofuStore.GetString(idKey(domain, port))
	if len(fingerprint) != 2*sha256.Size {
		// Missing, or not a hex SHA-256 digest
		return "", time.Time{}, errors.New("not found") //nolint:goerr113
	}

	if notAfter := tofuStore.GetTime(expiryKey(domain, port)); !notAfter.IsZero() {
		return fingerprint, notAfter, nil
	}
	// Fingerprint present but no expiry recorded
	return fingerprint, time.Time{}, errors.New("not found") //nolint:goerr113
}
// certID returns the fingerprint used to identify a cert: the SHA-256 sum
// of its RawSubjectPublicKeyInfo, as uppercase hex.
// The public key info is hashed rather than cert.Raw, see #7.
func certID(cert *x509.Certificate) string {
	sum := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
	return fmt.Sprintf("%X", sum[:])
}
// origCertID returns the fingerprint format used in v1.0.0 of the app:
// the SHA-256 sum of the whole cert (cert.Raw), as uppercase hex.
func origCertID(cert *x509.Certificate) string {
	sum := sha256.Sum256(cert.Raw)
	return fmt.Sprintf("%X", sum[:])
}
func saveTofuEntry(domain, port string, cert *x509.Certificate) {
tofuStoreMu.Lock()
defer tofuStoreMu.Unlock()
tofuStore.Set(idKey(domain, port), certID(cert))
tofuStore.Set(expiryKey(domain, port), cert.NotAfter.UTC())
tofuStore.WriteConfig() //nolint:errcheck // Not an issue if it's not saved, only cached data
}
// handleTofu is the single entry point for TOFU handling: given the cert a
// server presented, it checks it against the stored entry for the host and
// updates the store as needed.
//
// It returns true when the cert is acceptable according to the TOFU
// database, and false when it is not. If false is returned, the connection
// should not go ahead.
func handleTofu(domain, port string, cert *x509.Certificate) bool {
	storedID, storedExpiry, err := loadTofuEntry(domain, port)

	switch {
	case err != nil:
		// No entry, or the entry is malformed: nothing to compare
		// against, so trust this cert and record it.
		saveTofuEntry(domain, port, cert)

	case certID(cert) == storedID:
		// Same cert as the one stored.
		// Refresh the stored expiry in case it changed.
		tofuStoreMu.Lock()
		tofuStore.Set(expiryKey(domain, port), cert.NotAfter.UTC())
		tofuStore.WriteConfig() //nolint:errcheck
		tofuStoreMu.Unlock()

	case origCertID(cert) == storedID:
		// Matches, but the entry uses the old ID type; re-save so the
		// entry is stored with the current one.
		saveTofuEntry(domain, port, cert)

	case time.Now().After(storedExpiry):
		// The stored cert has expired, so any replacement is accepted.
		saveTofuEntry(domain, port, cert)

	default:
		// Different cert while the stored one is still unexpired.
		return false
	}
	return true
}
// ResetTofuEntry forces the cert passed to be valid, overwriting any previous TOFU entry.
// It stores the cert through saveTofuEntry directly, without any of the
// comparisons handleTofu makes against the existing entry.
// The port string can be empty, to indicate port 1965.
func ResetTofuEntry(domain, port string, cert *x509.Certificate) {
saveTofuEntry(domain, port, cert)
}
// GetExpiry returns the stored expiry date for the given host.
// The time will be empty (zero) if there is not expiry date stored for that host.
func GetExpiry(domain, port string) time.Time {
tofuStoreMu.RLock()
defer tofuStoreMu.RUnlock()
return tofuStore.GetTime(expiryKey(domain, port))
}