I have implemented a simple WebSocket proxy with FastAPI (using this example)
The application's only job is to pass every message it receives through to its active connections (i.e., to act as a proxy).
It works only with a single instance, because it keeps the active WebSocket connections in memory, and memory is not shared between instances.
My naive approach was to keep the active connections in some shared storage (Redis), but I got stuck trying to pickle them.
Here is the complete app:
import pickle
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from collections import defaultdict
import redis
# ASGI application object; the websocket route below is registered on it.
app = FastAPI()
# Redis client created with 'localhost'; used below in the attempt to store the connections.
rds = redis.StrictRedis('localhost')
class ConnectionManager:
    """Track the WebSocket connections accepted by this process.

    Connections are grouped as ``active_connections[application][client_id]``,
    a list of sockets, so one client may hold several connections at once.
    """

    def __init__(self):
        self.active_connections = defaultdict(dict)

    async def connect(self, websocket: WebSocket, application: str, client_id: str):
        """Accept *websocket* and register it under (application, client_id)."""
        await websocket.accept()
        # setdefault() works whether or not an entry (of either dict flavour)
        # already exists, so registration never trips over a missing key.
        clients = self.active_connections.setdefault(application, defaultdict(list))
        clients.setdefault(client_id, []).append(websocket)
        #### this is my attempt to store connections ####
        # NOTE(review): this is the call that fails -- the live WebSocket
        # objects held in active_connections cannot be pickled.
        rds.set('connections', pickle.dumps(self.active_connections))

    def disconnect(self, websocket: WebSocket, application: str, client_id: str):
        """Forget *websocket*; a no-op when it is not (or no longer) registered."""
        clients = self.active_connections.get(application)
        if clients is None:
            return
        sockets = clients.get(client_id)
        if sockets is None:
            return
        try:
            sockets.remove(websocket)
        except ValueError:
            return
        # Drop empty containers so departed clients do not accumulate.
        if not sockets:
            del clients[client_id]
            if not clients:
                del self.active_connections[application]

    async def broadcast(self, message: dict, application: str, client_id: str):
        """Send *message* as JSON to every connection of (application, client_id).

        Delivery is best effort: a failure on one connection is reported and
        the remaining connections are still served.
        """
        # .get() avoids inserting entries for unknown applications/clients.
        sockets = self.active_connections.get(application, {}).get(client_id, ())
        # Iterate over a snapshot: a disconnect may run while we are awaiting.
        for connection in list(sockets):
            try:
                await connection.send_json(message)
                print(f"sent {message}")
            except Exception as exc:
                print(f"failed to send {message}: {exc!r}")
# Module-level manager instance used by the websocket endpoint below.
manager = ConnectionManager()
@app.websocket("/ws/channel/{application}/{client_id}/")
async def websocket_endpoint(websocket: WebSocket, application: str, client_id: str):
    """Register the socket, then relay every JSON message it sends.

    Each received message is broadcast to all connections registered for the
    same (application, client_id). The socket is unregistered when the handler
    exits, however it exits.
    """
    await manager.connect(websocket, application, client_id)
    try:
        while True:
            data = await websocket.receive_json()
            print(f"received: {data}")
            await manager.broadcast(data, application, client_id)
    except (WebSocketDisconnect, RuntimeError):
        # The peer went away (or the socket is no longer usable): stop
        # relaying instead of looping into another receive call.
        pass
    finally:
        # Always unregister, so a closed socket never stays in the manager.
        manager.disconnect(websocket, application, client_id)
if __name__ == '__main__':
    # Script entry point: uvicorn is imported only when the file is run
    # directly, and serves the app on all interfaces, port 8005.
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8005)
However, pickling the WebSocket connections failed:
AttributeError: Can't pickle local object 'FastAPI.setup.<locals>.openapi'
What is the proper way to have WebSocket connections stored across the application instances?
UPD: The actual solution, based on @AKX's answer.
Each server instance subscribes to a Redis pub/sub channel and tries to send every message it receives to all of its own connected clients.
Because one client cannot be connected to several instances at once, each message is delivered to each client only once.
import json
import asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from collections import defaultdict
import redis
# ASGI application object; the startup hook and websocket route below are registered on it.
app = FastAPI()
# Redis client created with 'localhost'; used below both to publish and to subscribe.
rds = redis.StrictRedis('localhost')
class ConnectionManager:
    """Track this process's WebSocket connections and relay Redis messages to them.

    Connections are grouped as ``active_connections[application][client_id]``,
    a list of sockets. Only sockets accepted by this process are stored here;
    messages published on the Redis channel reach them through ``consume``.
    """

    def __init__(self):
        self.active_connections = defaultdict(dict)

    async def connect(self, websocket: WebSocket, application: str, client_id: str):
        """Accept *websocket* and register it under (application, client_id)."""
        await websocket.accept()
        # setdefault() works whether or not an entry (of either dict flavour)
        # already exists, so registration never trips over a missing key.
        clients = self.active_connections.setdefault(application, defaultdict(list))
        clients.setdefault(client_id, []).append(websocket)

    def disconnect(self, websocket: WebSocket, application: str, client_id: str):
        """Forget *websocket*; a no-op when it is not (or no longer) registered."""
        clients = self.active_connections.get(application)
        if clients is None:
            return
        sockets = clients.get(client_id)
        if sockets is None:
            return
        try:
            sockets.remove(websocket)
        except ValueError:
            return
        # Drop empty containers so departed clients do not accumulate.
        if not sockets:
            del clients[client_id]
            if not clients:
                del self.active_connections[application]

    async def broadcast(self, message: dict, application: str, client_id: str):
        """Send *message* as JSON to every local connection of (application, client_id).

        Messages for applications or clients with no connection on this
        instance are ignored. Delivery is best effort: a failure on one
        connection is reported and the remaining connections are still served.
        """
        # .get() instead of indexing: indexing the outer defaultdict would
        # insert a plain dict for an unknown application and then raise
        # KeyError on the client lookup, killing the consumer task.
        sockets = self.active_connections.get(application, {}).get(client_id, ())
        # Iterate over a snapshot: a disconnect may run while we are awaiting.
        for connection in list(sockets):
            try:
                await connection.send_json(message)
                print(f"sent {message}")
            except Exception as exc:
                print(f"failed to send {message}: {exc!r}")

    async def consume(self, poll_interval: float = 0.01):
        """Poll the Redis ``'channel'`` subscription forever and fan messages out.

        Each payload is expected to be the JSON object published by the
        websocket endpoint, with the keys ``'application'``, ``'client_id'``
        and ``'message'``. Payloads that do not match are reported and skipped
        so that one bad message cannot stop the consumer.

        Args:
            poll_interval: Seconds to sleep between two polls of the subscription.
        """
        print("started to consume")
        sub = rds.pubsub()
        sub.subscribe('channel')
        while True:
            # Sleeping between polls hands control back to the event loop.
            await asyncio.sleep(poll_interval)
            message = sub.get_message(ignore_subscribe_messages=True)
            if not isinstance(message, dict):
                continue
            try:
                msg = json.loads(message.get('data'))
                target = (msg['message'], msg['application'], msg['client_id'])
            except (ValueError, TypeError, KeyError) as exc:
                print(f"skipping malformed message: {exc!r}")
                continue
            await self.broadcast(*target)
# Module-level manager instance shared by the startup hook and the websocket endpoint below.
manager = ConnectionManager()
@app.on_event("startup")
async def subscribe():
    """Start the Redis consumer as a background task when the app starts."""
    # The event loop keeps only weak references to tasks; hold a strong one on
    # the app so the consumer cannot be garbage-collected while running.
    app.state.consume_task = asyncio.create_task(manager.consume())
@app.websocket("/ws/channel/{application}/{client_id}/")
async def websocket_endpoint(websocket: WebSocket, application: str, client_id: str):
    """Register the socket, then publish every JSON message it sends to Redis.

    Each message is wrapped with its application and client_id and published
    on the ``'channel'`` topic. The socket is unregistered when the handler
    exits, however it exits.
    """
    await manager.connect(websocket, application, client_id)
    try:
        while True:
            data = await websocket.receive_json()
            print(f"received: {data}")
            # NOTE(review): rds is not awaited here, so this looks like a
            # synchronous call made from the event loop -- confirm acceptable.
            rds.publish(
                'channel',
                json.dumps({
                    'application': application,
                    'client_id': client_id,
                    'message': data,
                }),
            )
    except (WebSocketDisconnect, RuntimeError):
        # The peer went away (or the socket is no longer usable): stop
        # reading instead of looping into another receive call.
        pass
    finally:
        # Always unregister, so a closed socket never stays in the manager.
        manager.disconnect(websocket, application, client_id)
if __name__ == '__main__':  # pragma: no cover
    # Script entry point: uvicorn is imported only when the file is run
    # directly, and serves the app on all interfaces, port 8005.
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8005)