Skip to content

令牌桶限流

  1. 并发控制与限流的区别

并发控制:限制同时执行的任务数量。 ants.NewPool(5) 就是限制最多 5 个协程同时跑。超过的请求只能排队等待空位。

限流:限制单位时间内的请求数量(QPS),比如每秒最多 10 次请求,超过就拒绝或等待。 典型的限流工具是令牌桶(rate.Limiter)或者漏桶。


  1. 并发控制与限流的核心目标
  • 并发控制:保护自己,防止系统资源被撑爆
    • 并发控制关心的不是“多久做多少”,而是“同时做多少”。
  • 限流:保护别人(或自己对接的服务),防止请求频率过高
    • 限流关注的是“单位时间内的数量”。

  1. 并发控制与限流的场景

并发控制的场景:

  • 爬虫/批量任务:一口气爬 10 万个 URL,不能开 10 万个 goroutine,不然内存炸了。限制同时最多跑 50 个任务,跑完一个再补一个。
  • 批量文件处理:同时处理的文件太多会占满 CPU/IO,直接卡死。限制一次最多并行处理 8 个文件。
  • 数据库并发写:太多连接同时写数据库,DB 连接池爆满,事务卡死。限制并发数,保证数据库响应正常。

限流的场景:

  • 调用三方 API:三方规定 10 QPS 超过就封号。用令牌桶让请求以稳定速率发出,不突刺。
  • 秒杀活动下单。系统承受不了瞬时高流量,要按 QPS 限制进入的下单请求。
  • 定时任务/调度器。任务很多,但 API 每秒只允许 100 个请求,要分批分时间发。

单机令牌桶限流

go
package main

import (
	"context"
	"fmt"
	"io/ioutil"
	"log/slog"
	"net/http"
	"sync"
	"time"

	"golang.org/x/time/rate"
)

// 限流器单例
var limiter = rate.NewLimiter(10, 20) // 10 QPS,桶容量20

// 统一请求函数,只管发请求,不管限流
func callThirdPartyAPI(url string) (string, error) {
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(body), nil
}

// 调用入口,限流和请求分离
func limitedCall(ctx context.Context, url string, idx int) (string, error) {
	slog.Info("等待令牌中", "请求序号", idx)
	if err := limiter.Wait(ctx); err != nil {
		slog.Error("等待令牌失败", "请求序号", idx, "错误", err)
		return "", err
	}
	slog.Info("拿到令牌,开始请求", "请求序号", idx)
	resp, err := callThirdPartyAPI(url)
	if err != nil {
		slog.Error("请求失败", "请求序号", idx, "错误", err)
		return "", err
	}
	slog.Info("请求成功", "请求序号", idx)
	return resp, nil
}

func main() {
	ctx := context.Background()
	url := "https://echo.hoppscotch.io"

	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			resp, err := limitedCall(ctx, url, i)
			if err != nil {
				fmt.Printf("请求%d失败: %v\n", i, err)
				return
			}
			fmt.Printf("请求%d响应: %s\n", i, resp)
		}(i)
	}
	wg.Wait()
}

分布式令牌桶限流

学习理解分布式令牌桶限流的好样本

go
package main

import (
    "context"
    "fmt"
    "log"
    "sync"
    "time"

    "github.com/redis/go-redis/v9"
    "github.com/google/uuid"
)

var (
    ctx         = context.Background()
    redisClient *redis.Client

    tokenKey    = "dist_token_bucket"
    maxTokens   = 20         // 桶最大容量
    refillRate  = 10         // 每秒补充令牌数
    lockKey     = "dist_token_bucket_refill_lock"
    lockTTL     = 2 * time.Second
)

const luaScript = `
local tokens = tonumber(redis.call("GET", KEYS[1]) or "-1")
if tokens <= 0 then
    return 0
else
    redis.call("DECR", KEYS[1])
    return 1
end
`

func initTokenBucket() error {
    val, err := redisClient.Get(ctx, tokenKey).Int()
    if err == redis.Nil || val < 0 {
        return redisClient.Set(ctx, tokenKey, maxTokens, 0).Err()
    }
    return err
}

// 获取分布式锁,成功返回锁值,失败空字符串
func acquireLock(key string, ttl time.Duration) (string, error) {
    lockValue := uuid.New().String()
    ok, err := redisClient.SetNX(ctx, key, lockValue, ttl).Result()
    if err != nil {
        return "", err
    }
    if !ok {
        return "", nil
    }
    return lockValue, nil
}

// 释放锁,防止误删,用value校验
func releaseLock(key, value string) error {
    lua := `
    if redis.call("GET", KEYS[1]) == ARGV[1] then
        return redis.call("DEL", KEYS[1])
    else
        return 0
    end
    `
    res, err := redisClient.Eval(ctx, lua, []string{key}, value).Result()
    if err != nil {
        return err
    }
    if res.(int64) == 0 {
        return fmt.Errorf("释放锁失败,锁已被释放或不是本人持有")
    }
    return nil
}

// 补充令牌函数,安全加锁保证单实例补充
func refillTokensWithLock() {
    ticker := time.NewTicker(time.Second)
    defer ticker.Stop()

    for range ticker.C {
        lockValue, err := acquireLock(lockKey, lockTTL)
        if err != nil {
            log.Printf("[refill] 获取锁出错: %v", err)
            continue
        }
        if lockValue == "" {
            // 获取锁失败,跳过本次补充,等待下一轮
            continue
        }

        // 有锁,开始补充令牌
        for i := 0; i < refillRate; i++ {
            err := redisClient.Watch(ctx, func(tx *redis.Tx) error {
                tokensStr, err := tx.Get(ctx, tokenKey).Result()
                if err != nil && err != redis.Nil {
                    return err
                }
                tokens := 0
                if tokensStr != "" {
                    fmt.Sscanf(tokensStr, "%d", &tokens)
                }
                if tokens < maxTokens {
                    _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
                        pipe.Incr(ctx, tokenKey)
                        return nil
                    })
                    return err
                }
                return nil
            }, tokenKey)

            if err != nil {
                log.Printf("[refill] 补充令牌出错: %v", err)
                break
            }
        }

        // 释放锁
        if err := releaseLock(lockKey, lockValue); err != nil {
            log.Printf("[refill] 释放锁失败: %v", err)
        }
    }
}

// 抢令牌,成功true,失败false
func acquireToken() (bool, error) {
    res, err := redisClient.Eval(ctx, luaScript, []string{tokenKey}).Int()
    if err != nil {
        return false, err
    }
    return res == 1, nil
}

func main() {
    redisClient = redis.NewClient(&redis.Options{
        Addr: "localhost:6379",
    })

    // 初始化桶容量
    if err := initTokenBucket(); err != nil {
        log.Fatalf("初始化令牌桶失败: %v", err)
    }

    // 启动补充令牌协程(多个实例同时跑也安全)
    go refillTokensWithLock()

    var wg sync.WaitGroup
    // 模拟50个请求并发抢令牌
    for i := 0; i < 50; i++ {
        wg.Add(1)
        go func(idx int) {
            defer wg.Done()
            ok, err := acquireToken()
            if err != nil {
                log.Printf("请求%d出错: %v", idx, err)
                return
            }
            if ok {
                log.Printf("请求%d: 获得令牌,执行业务", idx)
            } else {
                log.Printf("请求%d: 无令牌,限流拒绝", idx)
            }
        }(i)
        time.Sleep(100 * time.Millisecond)
    }
    wg.Wait()
}