rum-goggles/vendor/github.com/r3labs/sse/v2/client.go
2024-02-23 11:39:16 -05:00

391 lines
10 KiB
Go

/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package sse
import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"gopkg.in/cenkalti/backoff.v1"
)
// Field-name prefixes that processEvent matches at the start of each line of
// a raw event. Each includes the trailing colon.
var (
// headerID prefixes the line carrying the event ID.
headerID = []byte("id:")
// headerData prefixes a line carrying (part of) the event payload.
headerData = []byte("data:")
// headerEvent prefixes the line carrying the event name.
headerEvent = []byte("event:")
// headerRetry prefixes the line carrying the retry value.
headerRetry = []byte("retry:")
)
// ClientMaxBufferSize returns a client option, suitable for passing to
// NewClient, that sets the client's maximum buffer size to s. The value is
// later handed to NewEventStreamReader whenever a stream is opened.
func ClientMaxBufferSize(s int) func(c *Client) {
	setSize := func(client *Client) {
		client.maxBufferSize = s
	}
	return setSize
}
// ConnCallback is the signature of the hooks registered through OnConnect and
// OnDisconnect. The hook receives the client whose connection state changed.
type ConnCallback func(c *Client)
// ResponseValidator inspects the HTTP response of a subscription request.
// When Client.ResponseValidator is set it is called instead of the built-in
// "status code must be 200" check; returning a non-nil error fails the current
// connection attempt and hands that error to the reconnect strategy.
type ResponseValidator func(c *Client, resp *http.Response) error
// Client handles an incoming server stream
type Client struct {
// Retry is not referenced by any code visible in this file.
// NOTE(review): purpose unclear from here — confirm against other files.
Retry time.Time
// ReconnectStrategy is the backoff policy used by the Subscribe* methods;
// when nil they fall back to backoff.NewExponentialBackOff().
ReconnectStrategy backoff.BackOff
// disconnectcb is invoked by readLoop when reading fails with an error
// other than io.EOF. Set through OnDisconnect.
disconnectcb ConnCallback
// connectedcb is invoked by readLoop after a successful read while
// Connected is false. Set through OnConnect.
connectedcb ConnCallback
// subscribed maps each channel passed to SubscribeChan* to the quit channel
// that Unsubscribe signals and cleanup closes.
subscribed map[chan *Event]chan struct{}
// Headers are extra request headers; request applies them last, so they
// override the headers it sets itself.
Headers map[string]string
// ReconnectNotify is passed to backoff.RetryNotify as the notify callback.
ReconnectNotify backoff.Notify
// ResponseValidator, when non-nil, replaces the default 200 status check.
ResponseValidator ResponseValidator
// Connection is the HTTP client used to perform subscription requests.
Connection *http.Client
// URL is the endpoint requested on every (re)connection.
URL string
// LastEventID holds the ID of the most recent event that carried one; it is
// sent back as the Last-Event-ID request header.
LastEventID atomic.Value // []byte
// maxBufferSize is handed to NewEventStreamReader for each stream.
maxBufferSize int
// mu is held while registering, signalling and removing entries of subscribed.
mu sync.Mutex
// EncodingBase64 makes processEvent base64-decode every event's data.
EncodingBase64 bool
// Connected is toggled by readLoop, but only when the matching
// connect/disconnect callback is registered.
Connected bool
}
// NewClient builds a Client for the given URL with a fresh http.Client, empty
// header and subscription maps and a 64 KiB (1 << 16) maximum buffer size.
// The supplied options are then applied to the client in argument order.
func NewClient(url string, opts ...func(c *Client)) *Client {
	client := new(Client)
	client.URL = url
	client.Connection = &http.Client{}
	client.Headers = map[string]string{}
	client.subscribed = map[chan *Event]chan struct{}{}
	client.maxBufferSize = 1 << 16

	for i := range opts {
		opts[i](client)
	}

	return client
}
// Subscribe subscribes to the named stream and passes every received event to
// handler. It is shorthand for SubscribeWithContext with context.Background()
// and returns whatever that call returns.
func (c *Client) Subscribe(stream string, handler func(msg *Event)) error {
return c.SubscribeWithContext(context.Background(), stream, handler)
}
// SubscribeWithContext subscribes to stream and calls handler, on the calling
// goroutine, for every event received.
//
// Each connection attempt issues a request bound to ctx, validates the
// response (with c.ResponseValidator when set, otherwise by requiring status
// 200) and then consumes events until the read loop reports a result. A nil
// result (end of stream) ends the subscription; a non-nil error is handed to
// the reconnect strategy, which decides whether to try again.
//
// The returned error is the one backoff.RetryNotify finishes with: nil after a
// clean end of stream, otherwise the last error once the strategy gives up.
func (c *Client) SubscribeWithContext(ctx context.Context, stream string, handler func(msg *Event)) error {
	operation := func() error {
		resp, err := c.request(ctx, stream)
		if err != nil {
			return err
		}
		// Close the body on every exit path. Previously a response rejected by
		// the validator was returned without being closed, leaking the
		// connection on each retry.
		defer resp.Body.Close()

		if validator := c.ResponseValidator; validator != nil {
			if err = validator(c, resp); err != nil {
				return err
			}
		} else if resp.StatusCode != 200 {
			return fmt.Errorf("could not connect to stream: %s", http.StatusText(resp.StatusCode))
		}

		reader := NewEventStreamReader(resp.Body, c.maxBufferSize)
		eventChan, errorChan := c.startReadLoop(reader)

		for {
			select {
			case err = <-errorChan:
				return err
			case msg := <-eventChan:
				handler(msg)
			}
		}
	}

	// Apply the user specified reconnection strategy or default to the
	// standard exponential backoff.
	strategy := c.ReconnectStrategy
	if strategy == nil {
		strategy = backoff.NewExponentialBackOff()
	}
	return backoff.RetryNotify(operation, strategy, c.ReconnectNotify)
}
// SubscribeChan subscribes to the named stream and sends every received event
// to ch. It is shorthand for SubscribeChanWithContext with
// context.Background() and returns whatever that call returns.
func (c *Client) SubscribeChan(stream string, ch chan *Event) error {
return c.SubscribeChanWithContext(context.Background(), stream, ch)
}
// SubscribeChanWithContext subscribes to stream and forwards every received
// event to ch from a background goroutine.
//
// The call blocks until the outcome of connecting is known: it returns nil as
// soon as one response has been accepted, or the final error if the reconnect
// strategy gives up before any attempt succeeds. Errors that occur after the
// first successful connection are handled by the reconnect strategy and are
// not reported to the caller.
//
// The background goroutine stops when Unsubscribe(ch) is called, when the
// stream ends, or when the reconnect strategy gives up; it then removes ch
// from the client's subscriptions.
func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch chan *Event) error {
	var connected bool
	errch := make(chan error)

	// Keep a local handle on the quit channel. The goroutine below used to
	// read c.subscribed[ch] without holding c.mu, racing with other
	// subscriptions that write to the map under the lock.
	quit := make(chan struct{})
	c.mu.Lock()
	c.subscribed[ch] = quit
	c.mu.Unlock()

	operation := func() error {
		resp, err := c.request(ctx, stream)
		if err != nil {
			return err
		}
		// Close the body on every exit path, including validator rejection,
		// which previously leaked the response.
		defer resp.Body.Close()

		if validator := c.ResponseValidator; validator != nil {
			if err = validator(c, resp); err != nil {
				return err
			}
		} else if resp.StatusCode != 200 {
			return fmt.Errorf("could not connect to stream: %s", http.StatusText(resp.StatusCode))
		}

		if !connected {
			// Unblock the caller: the first connection succeeded.
			errch <- nil
			connected = true
		}

		reader := NewEventStreamReader(resp.Body, c.maxBufferSize)
		eventChan, errorChan := c.startReadLoop(reader)

		for {
			var msg *Event
			// Wait for a message to arrive or for a reason to exit.
			select {
			case <-quit:
				return nil
			case err = <-errorChan:
				return err
			case msg = <-eventChan:
			}

			// Wait for the message to be delivered or for a reason to exit.
			if msg != nil {
				select {
				case <-quit:
					return nil
				case ch <- msg:
					// message sent
				}
			}
		}
	}

	go func() {
		defer c.cleanup(ch)

		// Apply the user specified reconnection strategy or default to the
		// standard exponential backoff.
		strategy := c.ReconnectStrategy
		if strategy == nil {
			strategy = backoff.NewExponentialBackOff()
		}
		err := backoff.RetryNotify(operation, strategy, c.ReconnectNotify)

		// errch is closed by the caller once connected, so only report a
		// failure when no connection was ever established.
		if err != nil && !connected {
			errch <- err
		}
	}()

	err := <-errch
	close(errch)
	return err
}
// startReadLoop runs readLoop over reader in a new goroutine and returns the
// channel on which parsed events are delivered and the channel on which the
// loop reports its single terminal result (nil at end of stream, otherwise
// the read error).
func (c *Client) startReadLoop(reader *EventStreamReader) (chan *Event, chan error) {
	outCh := make(chan *Event)
	// readLoop sends exactly one value on the error channel right before it
	// returns. A buffer of one lets that send complete even when the
	// subscriber has already stopped receiving (for example after
	// Unsubscribe), so the goroutine is not left blocked on it forever.
	erChan := make(chan error, 1)
	go c.readLoop(reader, outCh, erChan)
	return outCh, erChan
}
// readLoop pulls raw events from reader until reading fails, parses each one
// and pushes those with content onto outCh.
//
// On io.EOF it sends nil on erChan and returns. On any other read error it
// runs the disconnect callback (if one is registered, clearing Connected
// first), sends the error on erChan and returns. After a successful read it
// runs the connect callback once while Connected is false (if one is
// registered). Events that fail to parse are skipped. An event without an ID
// inherits the last stored ID; an event with an ID updates it.
func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan chan error) {
	for {
		raw, readErr := reader.ReadEvent()
		switch {
		case readErr == io.EOF:
			erChan <- nil
			return
		case readErr != nil:
			// run user specified disconnect function
			if c.disconnectcb != nil {
				c.Connected = false
				c.disconnectcb(c)
			}
			erChan <- readErr
			return
		}

		if c.connectedcb != nil && !c.Connected {
			c.Connected = true
			c.connectedcb(c)
		}

		// Parse failures are deliberately ignored; move on to the next event.
		msg, parseErr := c.processEvent(raw)
		if parseErr != nil {
			continue
		}

		if len(msg.ID) == 0 {
			msg.ID, _ = c.LastEventID.Load().([]byte)
		} else {
			c.LastEventID.Store(msg.ID)
		}

		// Only forward events that carry something useful.
		if msg.hasContent() {
			outCh <- msg
		}
	}
}
// SubscribeRaw subscribes to the SSE endpoint without naming a stream: it
// calls Subscribe with an empty stream name, so no "stream" query parameter is
// added to the request.
func (c *Client) SubscribeRaw(handler func(msg *Event)) error {
return c.Subscribe("", handler)
}
// SubscribeRawWithContext subscribes to the SSE endpoint without naming a
// stream, binding requests to ctx: it calls SubscribeWithContext with an empty
// stream name, so no "stream" query parameter is added to the request.
func (c *Client) SubscribeRawWithContext(ctx context.Context, handler func(msg *Event)) error {
return c.SubscribeWithContext(ctx, "", handler)
}
// SubscribeChanRaw sends all events from the SSE endpoint to ch without naming
// a stream: it calls SubscribeChan with an empty stream name.
func (c *Client) SubscribeChanRaw(ch chan *Event) error {
return c.SubscribeChan("", ch)
}
// SubscribeChanRawWithContext sends all events from the SSE endpoint to ch
// without naming a stream, binding requests to ctx: it calls
// SubscribeChanWithContext with an empty stream name.
func (c *Client) SubscribeChanRawWithContext(ctx context.Context, ch chan *Event) error {
return c.SubscribeChanWithContext(ctx, "", ch)
}
// Unsubscribe signals the subscription registered for ch to stop. It is a
// no-op when ch has no registered subscription. The signal is sent on an
// unbuffered channel while the client's mutex is held, so the call blocks
// until the subscription receives it.
func (c *Client) Unsubscribe(ch chan *Event) {
	c.mu.Lock()
	defer c.mu.Unlock()

	quit := c.subscribed[ch]
	if quit == nil {
		return
	}
	quit <- struct{}{}
}
// OnDisconnect registers fn as the disconnect callback, replacing any
// previously registered one. The read loop calls it when reading from the
// stream fails with an error other than io.EOF.
func (c *Client) OnDisconnect(fn ConnCallback) {
c.disconnectcb = fn
}
// OnConnect registers fn as the connect callback, replacing any previously
// registered one. The read loop calls it after a successful read while the
// client is not yet marked as connected.
func (c *Client) OnConnect(fn ConnCallback) {
c.connectedcb = fn
}
// request performs the GET request that opens the event stream.
//
// The request targets c.URL and is bound to ctx. A non-empty stream is added
// as the "stream" query parameter. The standard SSE headers are set, plus
// Last-Event-ID when an event ID has been stored. Headers from c.Headers are
// applied last and therefore override any of the former.
//
// It returns the response from c.Connection.Do, or an error if the request
// could not be built or sent.
func (c *Client) request(ctx context.Context, stream string) (*http.Response, error) {
	// Building the request with its context in one step avoids the extra
	// request copy made by WithContext and reports a nil context as an error
	// instead of panicking.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.URL, nil)
	if err != nil {
		return nil, err
	}

	// Setup request, specify stream to connect to
	if stream != "" {
		query := req.URL.Query()
		query.Add("stream", stream)
		req.URL.RawQuery = query.Encode()
	}

	req.Header.Set("Cache-Control", "no-cache")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Connection", "keep-alive")

	// Load yields nil until an ID has been stored, in which case the type
	// assertion fails and the header is omitted.
	if lastID, ok := c.LastEventID.Load().([]byte); ok && lastID != nil {
		req.Header.Set("Last-Event-ID", string(lastID))
	}

	// Add user specified headers
	for k, v := range c.Headers {
		req.Header.Set(k, v)
	}

	return c.Connection.Do(req)
}
// processEvent parses the raw bytes of one event into an Event.
//
// msg is split into lines on "\n" and "\r" (empty lines are dropped). Lines
// starting with "id:", "event:" and "retry:" set the corresponding field, the
// last occurrence winning. Every "data:" line, and every line that is exactly
// "data", contributes one "\n"-terminated segment to Data; the final "\n" is
// trimmed. Any other line is ignored. When c.EncodingBase64 is set, Data is
// then base64-decoded (standard encoding).
//
// It returns an error when msg is empty or when base64 decoding fails; in both
// cases the returned event is nil.
func (c *Client) processEvent(msg []byte) (event *Event, err error) {
	if len(msg) < 1 {
		return nil, errors.New("event message was empty")
	}

	var e Event
	isLineBreak := func(r rune) bool { return r == '\n' || r == '\r' }

	// Split the message by "\n" or "\r", per the spec.
	for _, line := range bytes.FieldsFunc(msg, isLineBreak) {
		switch {
		case bytes.HasPrefix(line, headerID):
			e.ID = append([]byte(nil), trimHeader(len(headerID), line)...)
		case bytes.HasPrefix(line, headerData):
			// The spec allows for multiple data fields per event, concatenate them with "\n".
			e.Data = append(e.Data, trimHeader(len(headerData), line)...)
			e.Data = append(e.Data, '\n')
		// The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
		case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
			e.Data = append(e.Data, '\n')
		case bytes.HasPrefix(line, headerEvent):
			e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...)
		case bytes.HasPrefix(line, headerRetry):
			e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...)
		default:
			// Ignore any garbage that doesn't match what we're looking for.
		}
	}

	// Trim the last "\n" per the spec.
	e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))

	if c.EncodingBase64 {
		buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))
		// Use a distinct error variable: declaring "err" with := here used to
		// shadow the named result, so decode failures were silently dropped
		// and a partially decoded event was returned as if it were valid.
		n, decodeErr := base64.StdEncoding.Decode(buf, e.Data)
		if decodeErr != nil {
			return nil, fmt.Errorf("failed to decode event message: %w", decodeErr)
		}
		e.Data = buf[:n]
	}

	return &e, nil
}
// cleanup tears down the subscription registered for ch: under the client's
// mutex it closes the associated quit channel and removes the entry from the
// subscription map. It does nothing when ch has no registered subscription.
func (c *Client) cleanup(ch chan *Event) {
	c.mu.Lock()
	defer c.mu.Unlock()

	quit := c.subscribed[ch]
	if quit == nil {
		return
	}
	close(quit)
	delete(c.subscribed, ch)
}
// trimHeader strips a field-name prefix of size bytes from data and returns
// the remaining value. After the prefix it removes at most one leading space
// and at most one trailing line feed. When data is nil or shorter than size it
// is returned unchanged. The result shares data's backing array.
func trimHeader(size int, data []byte) []byte {
	if data == nil || size > len(data) {
		return data
	}

	value := data[size:]

	// Remove optional leading whitespace
	if len(value) > 0 && value[0] == ' ' {
		value = value[1:]
	}

	// Remove trailing new line
	if last := len(value) - 1; last >= 0 && value[last] == '\n' {
		value = value[:last]
	}

	return value
}