gitlab-org--gitlab-foss/workhorse/internal/redis/redis.go

295 lines
8.1 KiB
Go

package redis
import (
"errors"
"fmt"
"net"
"net/url"
"time"
"github.com/FZambia/sentinel"
"github.com/gomodule/redigo/redis"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
// pool is the shared Redis connection pool. Configure builds it; it stays
// nil until Configure has been called with a non-nil config.
var pool *redis.Pool

// sntnl is the Sentinel client built by Configure. It is nil when no
// Sentinel addresses were configured.
var sntnl *sentinel.Sentinel
// Connection-pool defaults used by Configure.
const (
	// defaultMaxIdle caps the number of idle connections kept in the pool.
	// Configure uses it when RedisConfig.MaxIdle is unset.
	defaultMaxIdle = 1
	// defaultMaxActive caps the number of live connections in the pool.
	// Configure uses it when RedisConfig.MaxActive is unset.
	defaultMaxActive = 1
	// defaultIdleTimeout is how long a connection may sit unused in the pool
	// before it is closed. 3 minutes seemed good. If you actually hit this
	// timeout often, consider turning off Redis support, since it is not
	// necessary at that point.
	defaultIdleTimeout = 3 * time.Minute
)

// Per-connection defaults applied when dialing.
const (
	// defaultReadTimeout bounds read operations when dial timeouts are
	// enabled. One second is technically overkill; it is only a sanity limit.
	defaultReadTimeout = time.Second
	// defaultWriteTimeout bounds write operations when dial timeouts are
	// enabled. One second is technically overkill; it is only a sanity limit.
	defaultWriteTimeout = time.Second
	// defaultKeepAlivePeriod is the TCP keep-alive period. It keeps a TCP
	// connection open for an extended period without it being killed. It is
	// used both in the pool and in the worker connection. See
	// https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more
	// information.
	defaultKeepAlivePeriod = 5 * time.Minute
)
// totalConnections counts every successful Redis dial made by this process.
// Its rate tracks the Redis connection rate.
var totalConnections = promauto.NewCounter(prometheus.CounterOpts{
	Name: "gitlab_workhorse_redis_total_connections",
	Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process",
})

// errorCounter counts Redis-related errors. It is labelled by error type and
// by destination ("redis" or "sentinel").
var errorCounter = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Name: "gitlab_workhorse_redis_errors",
		Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)",
	},
	[]string{"type", "dst"},
)
// sentinelConn builds a Sentinel client that discovers the master named
// master through the given Sentinel URLs.
//
// It returns nil when urls is empty. Elsewhere in this package, a nil
// sntnl means that Sentinel is not used.
func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel {
	if len(urls) == 0 {
		return nil
	}

	addrs := make([]string, 0, len(urls))
	for i := range urls {
		u := urls[i].URL
		log.WithFields(log.Fields{
			"scheme": u.Scheme,
			"host":   u.Host,
		}).Printf("redis: using sentinel")
		addrs = append(addrs, u.String())
	}

	// The Sentinel client guidelines
	// (https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel)
	// say each Sentinel should be tried with a short timeout. That timeout
	// should be on the order of a few hundred milliseconds.
	const sentinelTimeout = 500 * time.Millisecond

	dial := func(addr string) (redis.Conn, error) {
		parsed := helper.URLMustParse(addr)
		opts := []redis.DialOption{
			redis.DialConnectTimeout(sentinelTimeout),
			redis.DialReadTimeout(sentinelTimeout),
			redis.DialWriteTimeout(sentinelTimeout),
		}

		var (
			conn redis.Conn
			err  error
		)
		switch parsed.Scheme {
		case "redis", "rediss":
			conn, err = redis.DialURL(addr, opts...)
		default:
			conn, err = redis.Dial("tcp", parsed.Host, opts...)
		}

		if err != nil {
			errorCounter.WithLabelValues("dial", "sentinel").Inc()
			return nil, err
		}
		return conn, nil
	}

	return &sentinel.Sentinel{
		Addrs:      addrs,
		MasterName: master,
		Dial:       dial,
	}
}
// Dial functions installed by Configure. poolDialFunc backs the connection
// pool. workerDialFunc is produced with the timeout flag turned off.
var (
	poolDialFunc   func() (redis.Conn, error)
	workerDialFunc func() (redis.Conn, error)
)
// timeoutDialOptions returns a read-timeout and a write-timeout dial option.
// Each value comes from cfg when set and from the package default otherwise.
// A nil cfg yields the defaults for both.
func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
	read, write := defaultReadTimeout, defaultWriteTimeout
	if cfg != nil && cfg.ReadTimeout != nil {
		read = cfg.ReadTimeout.Duration
	}
	if cfg != nil && cfg.WriteTimeout != nil {
		write = cfg.WriteTimeout.Duration
	}

	return []redis.DialOption{
		redis.DialReadTimeout(read),
		redis.DialWriteTimeout(write),
	}
}
// dialOptionsBuilder assembles the redigo dial options derived from cfg:
//   - read/write timeouts, only when setTimeouts is true;
//   - the password, when it is non-empty;
//   - the database index, when it is set.
//
// With a nil cfg, only the default timeouts (if requested) are returned.
func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption {
	var opts []redis.DialOption
	if setTimeouts {
		opts = timeoutDialOptions(cfg)
	}

	if cfg != nil {
		if cfg.Password != "" {
			opts = append(opts, redis.DialPassword(cfg.Password))
		}
		if cfg.DB != nil {
			opts = append(opts, redis.DialDatabase(*cfg.DB))
		}
	}

	return opts
}
// keepAliveDialer returns a net-dial function suitable for redis.DialNetDial.
// The function resolves address, opens a TCP connection, and enables TCP
// keep-alive on it with the given period. Despite its name, timeout is the
// keep-alive period, not a connect timeout.
//
// If enabling keep-alive fails, the newly opened connection is closed before
// the error is returned, so the socket is not leaked.
func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
	return func(network, address string) (net.Conn, error) {
		addr, err := net.ResolveTCPAddr(network, address)
		if err != nil {
			return nil, err
		}

		tc, err := net.DialTCP(network, nil, addr)
		if err != nil {
			return nil, err
		}

		if err := tc.SetKeepAlive(true); err != nil {
			// Best-effort cleanup: the keep-alive error is the one worth reporting.
			_ = tc.Close()
			return nil, err
		}
		if err := tc.SetKeepAlivePeriod(timeout); err != nil {
			_ = tc.Close()
			return nil, err
		}

		return tc, nil
	}
}
// redisDialerFunc is the signature shared by the dialers in this file. Each
// call opens one new Redis connection.
type redisDialerFunc func() (redis.Conn, error)

// sentinelDialer returns a dialer that works through the package-level
// Sentinel client (sntnl). On every call it asks sntnl for the current master
// address and connects to it over TCP with keep-alive enabled. A failed
// master lookup increments the ("master", "sentinel") error counter.
//
// The option list is built once, outside the returned closure. Appending to
// the captured slice on every dial would grow it by one option per connection.
// It would also mutate shared state from whichever goroutine the pool dials on.
func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc {
	// Clip capacity so append copies instead of writing into the caller's
	// backing array.
	opts := append(dopts[:len(dopts):len(dopts)], redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))

	return func() (redis.Conn, error) {
		address, err := sntnl.MasterAddr()
		if err != nil {
			errorCounter.WithLabelValues("master", "sentinel").Inc()
			return nil, err
		}
		return redisDial("tcp", address, opts...)
	}
}
// defaultDialer returns a dialer for the single Redis server described by url.
// How it connects depends on url.Scheme:
//   - unix: dials url.Path over a Unix socket with dopts unchanged.
//   - redis and rediss: use redis.DialURL.
//   - anything else: dials url.Host, using the scheme as the network name.
//
// All non-unix schemes also get a keep-alive TCP net-dialer.
//
// The TCP option list is built once, outside the returned closure. Appending
// to the captured dopts on every dial would grow it by one option per
// connection and mutate state shared across dials.
func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc {
	// Clip capacity so append copies; unix dials keep using the original dopts.
	tcpOpts := append(dopts[:len(dopts):len(dopts)], redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))

	return func() (redis.Conn, error) {
		switch url.Scheme {
		case "unix":
			return redisDial(url.Scheme, url.Path, dopts...)
		case "redis", "rediss":
			// redis.DialURL only works with redis[s]:// URLs.
			return redisURLDial(url, tcpOpts...)
		default:
			return redisDial(url.Scheme, url.Host, tcpOpts...)
		}
	}
}
// redisURLDial logs the scheme and host being dialed, then connects with
// redis.DialURL. That call only accepts redis:// and rediss:// URLs.
func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) {
	fields := log.Fields{
		"scheme":  url.Scheme,
		"address": url.Host,
	}
	log.WithFields(fields).Printf("redis: dialing")

	rawURL := url.String()
	return redis.DialURL(rawURL, options...)
}
// redisDial logs the network and address being dialed, then opens the
// connection with redis.Dial using the supplied options.
func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) {
	fields := log.Fields{"network": network, "address": address}
	log.WithFields(fields).Printf("redis: dialing")
	return redis.Dial(network, address, options...)
}
// countDialer wraps dialer so every attempt is recorded in the package
// metrics. A failure increments errorCounter with labels ("dial", "redis"); a
// success increments totalConnections. The wrapped dialer's results are
// passed through unchanged.
func countDialer(dialer redisDialerFunc) redisDialerFunc {
	return func() (redis.Conn, error) {
		conn, err := dialer()
		if err == nil {
			totalConnections.Inc()
			return conn, nil
		}
		errorCounter.WithLabelValues("dial", "redis").Inc()
		return conn, err
	}
}
// DefaultDialFunc returns the dial function that should always be used; the
// only exception is unit tests.
//
// If the package-level Sentinel client is set at call time, the returned
// function dials through Sentinel. Otherwise it dials cfg.URL directly.
// In both cases it:
//   - uses cfg.KeepAlivePeriod, or defaultKeepAlivePeriod when that is unset;
//   - applies the options from dialOptionsBuilder;
//   - records metrics through countDialer.
//
// setReadTimeout controls whether read and write timeouts are applied.
func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
	period := defaultKeepAlivePeriod
	if p := cfg.KeepAlivePeriod; p != nil {
		period = p.Duration
	}
	opts := dialOptionsBuilder(cfg, setReadTimeout)

	var dial redisDialerFunc
	if sntnl == nil {
		dial = defaultDialer(opts, period, cfg.URL.URL)
	} else {
		dial = sentinelDialer(opts, period)
	}
	return countDialer(dial)
}
// Configure initializes the package-level Redis state from cfg. It sets up:
//   - the optional Sentinel client;
//   - the worker and pool dial functions, both produced by dialFunc (the pool
//     variant is requested with timeouts enabled);
//   - the connection pool.
//
// A nil cfg leaves everything untouched.
func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) {
	if cfg == nil {
		return
	}

	maxIdle, maxActive := defaultMaxIdle, defaultMaxActive
	if cfg.MaxIdle != nil {
		maxIdle = *cfg.MaxIdle
	}
	if cfg.MaxActive != nil {
		maxActive = *cfg.MaxActive
	}

	// sntnl must be assigned before dialFunc runs, because DefaultDialFunc
	// inspects it to choose between Sentinel and direct dialing.
	sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel)
	workerDialFunc = dialFunc(cfg, false)
	poolDialFunc = dialFunc(cfg, true)

	pool = &redis.Pool{
		MaxIdle:     maxIdle,            // upper bound on hot (idle) connections
		MaxActive:   maxActive,          // upper bound on live connections; 0 means unlimited
		IdleTimeout: defaultIdleTimeout, // unused connections are closed after this long
		Dial:        poolDialFunc,
		Wait:        true,
	}

	if sntnl == nil {
		return
	}
	// With Sentinel, reject borrowed connections that fail the master role check.
	pool.TestOnBorrow = func(c redis.Conn, _ time.Time) error {
		if sentinel.TestRole(c, "master") {
			return nil
		}
		return errors.New("role check failed")
	}
}
// Get borrows a connection from the Redis pool. It returns nil when Configure
// has not created a pool yet. Callers should Close the returned connection,
// as GetString does.
func Get() redis.Conn {
	if pool == nil {
		return nil
	}
	return pool.Get()
}
// GetString runs a Redis GET for key and converts the reply with
// redis.String. It returns an error if no pooled connection is available.
func GetString(key string) (string, error) {
	c := Get()
	if c == nil {
		return "", fmt.Errorf("redis: could not get connection from pool")
	}
	defer c.Close()

	reply, err := c.Do("GET", key)
	return redis.String(reply, err)
}