FastAPI WebSocket replication

I have implemented a simple WebSocket proxy with FastAPI (using this example).

The application's only job is to pass every message it receives through to all of its active connections (a proxy).

It works well with a single instance only, because it keeps the active WebSocket connections in memory, and that memory is not shared when more than one instance is running.
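For illustration, here is a minimal stdlib sketch (POSIX-only, using os.fork in place of real worker processes) of the underlying problem: each process mutates its own copy of module-level state, so a connection registered in one worker is invisible to the others:

```python
import os

connections = []  # module-level state, like the in-memory connection dict

pid = os.fork()  # fork a second "instance" of this process
if pid == 0:
    # child: registers a connection in *its own copy* of the list
    connections.append("client-1")
    os._exit(0)

os.waitpid(pid, 0)
print(connections)  # the parent's list is still empty: []
```

The same thing happens when uvicorn/gunicorn spawns multiple workers: each worker holds its own `active_connections` dict.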

My naive approach was to solve this by keeping the active connections in some shared storage (Redis), but I got stuck on pickling them.

Here is the complete app:

import pickle

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from collections import defaultdict
import redis

app = FastAPI()
rds = redis.StrictRedis('localhost')

class ConnectionManager:
    def __init__(self):
        # application -> client_id -> list of websockets
        self.active_connections = defaultdict(lambda: defaultdict(list))

    async def connect(self, websocket: WebSocket, application: str, client_id: str):
        await websocket.accept()
        self.active_connections[application][client_id].append(websocket)

        #### this is my attempt to store connections ####
        rds.set('connections', pickle.dumps(self.active_connections)) 

    def disconnect(self, websocket: WebSocket, application: str, client_id: str):
        self.active_connections[application][client_id].remove(websocket)

    async def broadcast(self, message: dict, application: str, client_id: str):
        for connection in self.active_connections[application][client_id]:
            try:
                await connection.send_json(message)
                print(f"sent {message}")
            except Exception:
                # the socket may already be closed; skip it
                pass


manager = ConnectionManager()


@app.websocket("/ws/channel/{application}/{client_id}/")
async def websocket_endpoint(websocket: WebSocket, application: str, client_id: str):
    await manager.connect(websocket, application, client_id)
    while True:
        try:
            data = await websocket.receive_json()
            print(f"received: {data}")
            await manager.broadcast(data, application, client_id)
        except WebSocketDisconnect:
            manager.disconnect(websocket, application, client_id)
            break
        except RuntimeError:
            break


if __name__ == '__main__':
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8005)

However, pickling websocket connection was not successful:

AttributeError: Can't pickle local object 'FastAPI.setup.<locals>.openapi'
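The error message is somewhat incidental (pickle apparently walks from the WebSocket object to the FastAPI app and trips over a local function there), but the limitation is more fundamental: a live network connection cannot be pickled at all, as a plain-socket sketch with no FastAPI involved shows:

```python
import pickle
import socket

sock = socket.socket()  # a plain TCP socket, no framework involved
try:
    pickle.dumps(sock)
except TypeError as exc:
    # sockets wrap an OS-level file descriptor, which has no
    # meaningful serialized form outside this process
    print(exc)
finally:
    sock.close()
```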

What is the proper way to have WebSocket connections stored across the application instances?

UPD: the actual solution, per @AKX's answer.

Each instance of the server subscribes to Redis pubsub and tries to send every received message to all of its own connected clients.

Since a single client cannot be connected to several instances at once, each message is delivered to each client only once.

import json
import asyncio


from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from collections import defaultdict
import redis

app = FastAPI()
rds = redis.StrictRedis('localhost')


class ConnectionManager:
    def __init__(self):
        # application -> client_id -> list of websockets
        self.active_connections = defaultdict(lambda: defaultdict(list))

    async def connect(self, websocket: WebSocket, application: str, client_id: str):
        await websocket.accept()
        self.active_connections[application][client_id].append(websocket)

    def disconnect(self, websocket: WebSocket, application: str, client_id: str):
        self.active_connections[application][client_id].remove(websocket)

    async def broadcast(self, message: dict, application: str, client_id: str):
        for connection in self.active_connections[application][client_id]:
            try:
                await connection.send_json(message)
                print(f"sent {message}")
            except Exception:
                # the socket may already be closed; skip it
                pass

    async def consume(self):
        print("started to consume")
        sub = rds.pubsub()
        sub.subscribe('channel')
        while True:
            # redis-py's pubsub client is blocking, so poll it with a short
            # sleep instead of blocking the event loop
            await asyncio.sleep(0.01)
            message = sub.get_message(ignore_subscribe_messages=True)
            if message is not None:
                msg = json.loads(message['data'])
                await self.broadcast(msg['message'], msg['application'], msg['client_id'])


manager = ConnectionManager()


@app.on_event("startup")
async def subscribe():
    asyncio.create_task(manager.consume())


@app.websocket("/ws/channel/{application}/{client_id}/")
async def websocket_endpoint(websocket: WebSocket, application: str, client_id: str):
    await manager.connect(websocket, application, client_id)
    while True:
        try:
            data = await websocket.receive_json()
            print(f"received: {data}")
            rds.publish(
                'channel',
                json.dumps({
                   'application': application,
                   'client_id': client_id,
                   'message': data
                })
            )
        except WebSocketDisconnect:
            manager.disconnect(websocket, application, client_id)
            break
        except RuntimeError:
            break


if __name__ == '__main__':  # pragma: no cover
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8005)
Doubletalk answered 12/2, 2022 at 10:10 Comment(2)
No. Only the messages sent after the moment of connection are relevant to the connected client. – Doubletalk
@Chris, yes, and there is the potential for high load. I think the pubsub approach suggested by AKX is a good way; I only need to figure out the implementation details. – Doubletalk

What is the proper way to have WebSocket connections stored across the application instances?

There is no practical way that I know of for multiple processes to share websocket connections. As you noticed, you can't pickle the connections (not least because you can't pickle the OS-level file descriptor that represents the network connection). You can send file descriptors to other processes with some POSIX magic, but even then you would need to make sure the processes agree on the state of the websocket and don't e.g. race at sending or receiving data.
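For the curious, the "POSIX magic" is SCM_RIGHTS file-descriptor passing over a Unix domain socket, which Python 3.9+ exposes as socket.send_fds / socket.recv_fds. A minimal sketch (POSIX-only, using os.fork and a pipe fd for brevity) of handing an fd to another process:

```python
import os
import socket

parent_sock, child_sock = socket.socketpair()  # AF_UNIX pair on POSIX
read_end, write_end = os.pipe()  # the fd we will hand to the other process

pid = os.fork()
if pid == 0:
    # child: receive the duplicated fd and write through it
    msg, fds, flags, addr = socket.recv_fds(child_sock, 1024, 1)
    os.write(fds[0], b"hello via passed fd")
    os._exit(0)

# parent: pass the pipe's write end over the Unix socket (SCM_RIGHTS)
socket.send_fds(parent_sock, [b"x"], [write_end])
os.close(write_end)  # drop the parent's copy; the child holds its own
os.waitpid(pid, 0)
data = os.read(read_end, 1024)
print(data)
```

This only moves the raw descriptor, which is exactly the problem the answer points out: the receiving process gets the socket but none of the websocket protocol state (handshake, framing, partial reads) that the original process holds.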

I would probably redesign things to have a single process that manages the websocket connections and uses e.g. Redis pubsub or streams (since you already have Redis) to communicate with the other processes you have.

Leafy answered 12/2, 2022 at 10:17 Comment(1)
Thanks! The pubsub approach seems to solve this. I've updated the question with the actual solution. – Doubletalk
