391 lines
10 KiB
Go
391 lines
10 KiB
Go
/* This Source Code Form is subject to the terms of the Mozilla Public
|
|
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
|
|
|
package sse
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"gopkg.in/cenkalti/backoff.v1"
|
|
)
|
|
|
|
// Field prefixes that processEvent looks for at the start of each line of an
// event block (matched with bytes.HasPrefix, stripped with trimHeader).
var headerID = []byte("id:")       // event identifier field
var headerData = []byte("data:")   // payload field; may appear several times
var headerEvent = []byte("event:") // event type field
var headerRetry = []byte("retry:") // retry field, stored verbatim in Event.Retry
|
|
|
|
func ClientMaxBufferSize(s int) func(c *Client) {
|
|
return func(c *Client) {
|
|
c.maxBufferSize = s
|
|
}
|
|
}
|
|
|
|
// ConnCallback defines a function to be called on a particular connection event
|
|
type ConnCallback func(c *Client)
|
|
|
|
// ResponseValidator validates a response
|
|
type ResponseValidator func(c *Client, resp *http.Response) error
|
|
|
|
// Client handles an incoming server stream.
//
// Create one with NewClient, which initialises the maps and the HTTP client
// that the zero value leaves nil.
type Client struct {
	// Retry is not read or written anywhere in this file.
	// NOTE(review): possibly unused — confirm before relying on or removing it.
	Retry time.Time

	// ReconnectStrategy is the back-off passed to backoff.RetryNotify by the
	// Subscribe* methods; when nil they use backoff.NewExponentialBackOff().
	ReconnectStrategy backoff.BackOff

	// disconnectcb and connectedcb are set by OnDisconnect / OnConnect and
	// invoked from readLoop.
	disconnectcb ConnCallback
	connectedcb  ConnCallback

	// subscribed maps each channel registered by SubscribeChanWithContext to
	// the quit channel that Unsubscribe sends on and cleanup closes. Writes to
	// the map happen while holding mu.
	subscribed map[chan *Event]chan struct{}

	// Headers are set on every request after the default SSE headers, so an
	// entry here overrides a default with the same name.
	Headers map[string]string

	// ReconnectNotify is passed to backoff.RetryNotify as its notify callback.
	ReconnectNotify backoff.Notify

	// ResponseValidator, when non-nil, replaces the default check that the
	// response status is 200.
	ResponseValidator ResponseValidator

	// Connection is the HTTP client used to send stream requests.
	Connection *http.Client

	// URL is the stream endpoint. request adds a "stream" query parameter to
	// it when a non-empty stream name is given.
	URL string

	// LastEventID holds the most recent non-empty event ID seen by readLoop.
	// request sends it as the Last-Event-ID header, and events that arrive
	// without an ID inherit it.
	LastEventID atomic.Value // []byte

	// maxBufferSize is handed to NewEventStreamReader. NewClient defaults it
	// to 1 << 16; ClientMaxBufferSize overrides it.
	maxBufferSize int

	// mu serialises writes to subscribed.
	mu sync.Mutex

	// EncodingBase64 makes processEvent base64-decode each event's data.
	EncodingBase64 bool

	// Connected is maintained by readLoop around the OnConnect/OnDisconnect
	// callbacks.
	// NOTE(review): written from read-loop goroutines without synchronization
	// and shared by every subscription on this Client.
	Connected bool
}
|
|
|
|
// NewClient creates a new client
|
|
func NewClient(url string, opts ...func(c *Client)) *Client {
|
|
c := &Client{
|
|
URL: url,
|
|
Connection: &http.Client{},
|
|
Headers: make(map[string]string),
|
|
subscribed: make(map[chan *Event]chan struct{}),
|
|
maxBufferSize: 1 << 16,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(c)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// Subscribe to a data stream
|
|
func (c *Client) Subscribe(stream string, handler func(msg *Event)) error {
|
|
return c.SubscribeWithContext(context.Background(), stream, handler)
|
|
}
|
|
|
|
// SubscribeWithContext to a data stream with context
|
|
func (c *Client) SubscribeWithContext(ctx context.Context, stream string, handler func(msg *Event)) error {
|
|
operation := func() error {
|
|
resp, err := c.request(ctx, stream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if validator := c.ResponseValidator; validator != nil {
|
|
err = validator(c, resp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if resp.StatusCode != 200 {
|
|
resp.Body.Close()
|
|
return fmt.Errorf("could not connect to stream: %s", http.StatusText(resp.StatusCode))
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
reader := NewEventStreamReader(resp.Body, c.maxBufferSize)
|
|
eventChan, errorChan := c.startReadLoop(reader)
|
|
|
|
for {
|
|
select {
|
|
case err = <-errorChan:
|
|
return err
|
|
case msg := <-eventChan:
|
|
handler(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Apply user specified reconnection strategy or default to standard NewExponentialBackOff() reconnection method
|
|
var err error
|
|
if c.ReconnectStrategy != nil {
|
|
err = backoff.RetryNotify(operation, c.ReconnectStrategy, c.ReconnectNotify)
|
|
} else {
|
|
err = backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), c.ReconnectNotify)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// SubscribeChan sends all events to the provided channel
|
|
func (c *Client) SubscribeChan(stream string, ch chan *Event) error {
|
|
return c.SubscribeChanWithContext(context.Background(), stream, ch)
|
|
}
|
|
|
|
// SubscribeChanWithContext sends all events to the provided channel with context
|
|
func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch chan *Event) error {
|
|
var connected bool
|
|
errch := make(chan error)
|
|
c.mu.Lock()
|
|
c.subscribed[ch] = make(chan struct{})
|
|
c.mu.Unlock()
|
|
|
|
operation := func() error {
|
|
resp, err := c.request(ctx, stream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if validator := c.ResponseValidator; validator != nil {
|
|
err = validator(c, resp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if resp.StatusCode != 200 {
|
|
resp.Body.Close()
|
|
return fmt.Errorf("could not connect to stream: %s", http.StatusText(resp.StatusCode))
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if !connected {
|
|
// Notify connect
|
|
errch <- nil
|
|
connected = true
|
|
}
|
|
|
|
reader := NewEventStreamReader(resp.Body, c.maxBufferSize)
|
|
eventChan, errorChan := c.startReadLoop(reader)
|
|
|
|
for {
|
|
var msg *Event
|
|
// Wait for message to arrive or exit
|
|
select {
|
|
case <-c.subscribed[ch]:
|
|
return nil
|
|
case err = <-errorChan:
|
|
return err
|
|
case msg = <-eventChan:
|
|
}
|
|
|
|
// Wait for message to be sent or exit
|
|
if msg != nil {
|
|
select {
|
|
case <-c.subscribed[ch]:
|
|
return nil
|
|
case ch <- msg:
|
|
// message sent
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
defer c.cleanup(ch)
|
|
// Apply user specified reconnection strategy or default to standard NewExponentialBackOff() reconnection method
|
|
var err error
|
|
if c.ReconnectStrategy != nil {
|
|
err = backoff.RetryNotify(operation, c.ReconnectStrategy, c.ReconnectNotify)
|
|
} else {
|
|
err = backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), c.ReconnectNotify)
|
|
}
|
|
|
|
// channel closed once connected
|
|
if err != nil && !connected {
|
|
errch <- err
|
|
}
|
|
}()
|
|
err := <-errch
|
|
close(errch)
|
|
return err
|
|
}
|
|
|
|
func (c *Client) startReadLoop(reader *EventStreamReader) (chan *Event, chan error) {
|
|
outCh := make(chan *Event)
|
|
erChan := make(chan error)
|
|
go c.readLoop(reader, outCh, erChan)
|
|
return outCh, erChan
|
|
}
|
|
|
|
func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan chan error) {
|
|
for {
|
|
// Read each new line and process the type of event
|
|
event, err := reader.ReadEvent()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
erChan <- nil
|
|
return
|
|
}
|
|
// run user specified disconnect function
|
|
if c.disconnectcb != nil {
|
|
c.Connected = false
|
|
c.disconnectcb(c)
|
|
}
|
|
erChan <- err
|
|
return
|
|
}
|
|
|
|
if !c.Connected && c.connectedcb != nil {
|
|
c.Connected = true
|
|
c.connectedcb(c)
|
|
}
|
|
|
|
// If we get an error, ignore it.
|
|
var msg *Event
|
|
if msg, err = c.processEvent(event); err == nil {
|
|
if len(msg.ID) > 0 {
|
|
c.LastEventID.Store(msg.ID)
|
|
} else {
|
|
msg.ID, _ = c.LastEventID.Load().([]byte)
|
|
}
|
|
|
|
// Send downstream if the event has something useful
|
|
if msg.hasContent() {
|
|
outCh <- msg
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SubscribeRaw to an sse endpoint
|
|
func (c *Client) SubscribeRaw(handler func(msg *Event)) error {
|
|
return c.Subscribe("", handler)
|
|
}
|
|
|
|
// SubscribeRawWithContext to an sse endpoint with context
|
|
func (c *Client) SubscribeRawWithContext(ctx context.Context, handler func(msg *Event)) error {
|
|
return c.SubscribeWithContext(ctx, "", handler)
|
|
}
|
|
|
|
// SubscribeChanRaw sends all events to the provided channel
|
|
func (c *Client) SubscribeChanRaw(ch chan *Event) error {
|
|
return c.SubscribeChan("", ch)
|
|
}
|
|
|
|
// SubscribeChanRawWithContext sends all events to the provided channel with context
|
|
func (c *Client) SubscribeChanRawWithContext(ctx context.Context, ch chan *Event) error {
|
|
return c.SubscribeChanWithContext(ctx, "", ch)
|
|
}
|
|
|
|
// Unsubscribe unsubscribes a channel
|
|
func (c *Client) Unsubscribe(ch chan *Event) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.subscribed[ch] != nil {
|
|
c.subscribed[ch] <- struct{}{}
|
|
}
|
|
}
|
|
|
|
// OnDisconnect specifies the function to run when the connection disconnects
|
|
func (c *Client) OnDisconnect(fn ConnCallback) {
|
|
c.disconnectcb = fn
|
|
}
|
|
|
|
// OnConnect specifies the function to run when the connection is successful
|
|
func (c *Client) OnConnect(fn ConnCallback) {
|
|
c.connectedcb = fn
|
|
}
|
|
|
|
func (c *Client) request(ctx context.Context, stream string) (*http.Response, error) {
|
|
req, err := http.NewRequest("GET", c.URL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req = req.WithContext(ctx)
|
|
|
|
// Setup request, specify stream to connect to
|
|
if stream != "" {
|
|
query := req.URL.Query()
|
|
query.Add("stream", stream)
|
|
req.URL.RawQuery = query.Encode()
|
|
}
|
|
|
|
req.Header.Set("Cache-Control", "no-cache")
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
req.Header.Set("Connection", "keep-alive")
|
|
|
|
lastID, exists := c.LastEventID.Load().([]byte)
|
|
if exists && lastID != nil {
|
|
req.Header.Set("Last-Event-ID", string(lastID))
|
|
}
|
|
|
|
// Add user specified headers
|
|
for k, v := range c.Headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
return c.Connection.Do(req)
|
|
}
|
|
|
|
func (c *Client) processEvent(msg []byte) (event *Event, err error) {
|
|
var e Event
|
|
|
|
if len(msg) < 1 {
|
|
return nil, errors.New("event message was empty")
|
|
}
|
|
|
|
// Normalize the crlf to lf to make it easier to split the lines.
|
|
// Split the line by "\n" or "\r", per the spec.
|
|
for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
|
|
switch {
|
|
case bytes.HasPrefix(line, headerID):
|
|
e.ID = append([]byte(nil), trimHeader(len(headerID), line)...)
|
|
case bytes.HasPrefix(line, headerData):
|
|
// The spec allows for multiple data fields per event, concatenated them with "\n".
|
|
e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...)
|
|
// The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
|
|
case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
|
|
e.Data = append(e.Data, byte('\n'))
|
|
case bytes.HasPrefix(line, headerEvent):
|
|
e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...)
|
|
case bytes.HasPrefix(line, headerRetry):
|
|
e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...)
|
|
default:
|
|
// Ignore any garbage that doesn't match what we're looking for.
|
|
}
|
|
}
|
|
|
|
// Trim the last "\n" per the spec.
|
|
e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))
|
|
|
|
if c.EncodingBase64 {
|
|
buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))
|
|
|
|
n, err := base64.StdEncoding.Decode(buf, e.Data)
|
|
if err != nil {
|
|
err = fmt.Errorf("failed to decode event message: %s", err)
|
|
}
|
|
e.Data = buf[:n]
|
|
}
|
|
return &e, err
|
|
}
|
|
|
|
func (c *Client) cleanup(ch chan *Event) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.subscribed[ch] != nil {
|
|
close(c.subscribed[ch])
|
|
delete(c.subscribed, ch)
|
|
}
|
|
}
|
|
|
|
// trimHeader drops the first size bytes of data (the field name and colon).
// It then removes at most one leading space and at most one trailing '\n'.
// data is returned unchanged when it is nil or shorter than size. The result
// shares data's backing array.
func trimHeader(size int, data []byte) []byte {
	tooShort := data == nil || len(data) < size
	if tooShort {
		return data
	}
	value := bytes.TrimPrefix(data[size:], []byte(" "))
	return bytes.TrimSuffix(value, []byte("\n"))
}
|