I'm using the following:
import asyncio
import ssl
from typing import Optional
async def tls_handshake(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    ssl_context: Optional[ssl.SSLContext] = None,
    server_side: bool = False,
    *,
    server_hostname: Optional[str] = None,
) -> None:
    """
    Manually perform a TLS handshake over an already-open stream.

    The transport behind ``writer`` is upgraded with ``loop.start_tls()`` and
    the resulting TLS transport is stored on both ``reader`` and ``writer``,
    so the same two objects keep being used after the call returns.

    Args:
        reader: The reader of the connection to upgrade.
        writer: The writer of the connection to upgrade.
        ssl_context: The SSL context to use for the TLS/SSL handshake.
            Defaults to None.
        server_side: Whether this end of the connection acts as the TLS
            server.
        server_hostname: Host name forwarded to ``loop.start_tls()``, which
            uses it as the name the server certificate is matched against.
            Recommended for client connections; defaults to None, which
            keeps the previous behaviour.

    Note:
        If the ssl context is not passed and the connection is not
        server_side, ``ssl.create_default_context()`` will be used. A
        server-side handshake has no such fallback and needs an explicit
        context.

        For Python 3.7 to 3.9 you can use ``ssl.PROTOCOL_TLS`` for the SSL
        context. For Python 3.10+ you need to either use
        ``ssl.PROTOCOL_TLS_CLIENT`` or ``ssl.PROTOCOL_TLS_SERVER`` depending
        on the role of the reader/writer.
    """
    if ssl_context is None and not server_side:
        ssl_context = ssl.create_default_context()

    transport = writer.transport
    protocol = transport.get_protocol()

    # This coroutine only ever executes inside a running loop, so ask for that
    # loop directly instead of going through the policy-based get_event_loop().
    loop = asyncio.get_running_loop()
    new_transport = await loop.start_tls(
        transport=transport,
        protocol=protocol,
        sslcontext=ssl_context,
        server_side=server_side,
        server_hostname=server_hostname,
    )

    # The stream objects expose no public way to swap their transport, so the
    # private attribute is replaced on both ends of the stream pair.
    reader._transport = new_transport
    writer._transport = new_transport
I've also added the above function to the `toolbox` package
(importable from `toolbox.asyncio.stream`) - see here.
Here is an example of manually performing the TLS handshake on a client connection:
import asyncio
import ssl
from toolbox.asyncio.stream import tls_handshake
async def client():
    """Open a plain TCP connection, then upgrade it to TLS by hand."""
    # ssl=False: connect without TLS so the handshake can be done manually.
    reader, writer = await asyncio.open_connection("httpbin.org", 443, ssl=False)
    await tls_handshake(reader=reader, writer=writer)
    # Communication is now encrypted.

    # Close the connection explicitly once done with it.
    writer.close()
    await writer.wait_closed()


asyncio.run(client())
And here is a more advanced example of a server & client:
import asyncio
import ssl
from toolbox.asyncio.stream import tls_handshake
# Address and port the example server listens on and the client connects to.
HOST = "127.0.0.1"
PORT = 8888
# Certificate file: loaded by the server as its certificate chain and by the
# client as the certificate it trusts.
CERT = "server.crt"
# Private key file the server loads together with CERT.
KEY = "server.key"
async def server_stream(reader, writer):
    """Handle one client: upgrade to TLS, read a message, reply, close.

    Args:
        reader: Stream reader for the accepted connection.
        writer: Stream writer for the accepted connection.
    """
    # Perform TLS handshake before sending/receiving data.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(CERT, KEY)
    await tls_handshake(
        reader=reader,
        writer=writer,
        ssl_context=context,
        server_side=True,
    )

    try:
        # Receive from client.
        data = await reader.read(1024)
        print("Received from client:", data)

        # Send from server.
        writer.write(b"Server here.")
        await writer.drain()
    finally:
        # Closes the connection server-side, even if the exchange above failed.
        writer.close()
        await writer.wait_closed()
async def client():
    """Connect to the example server, upgrade to TLS, and exchange a message."""
    # Open connection to server.
    reader, writer = await asyncio.open_connection(host=HOST, port=PORT)

    # Perform TLS handshake before sending/receiving data.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.load_verify_locations(CERT)
    await tls_handshake(
        reader=reader,
        writer=writer,
        ssl_context=context,
    )

    try:
        # Send from client.
        writer.write(b"Client here.")
        await writer.drain()

        # Receive from server.
        data = await reader.read(1024)
        print("Received from server:", data)
    finally:
        # Closes the connection client-side, even if the exchange above failed.
        writer.close()
        await writer.wait_closed()
async def main():
    """Start the example server, run one client exchange, then keep serving."""
    server = await asyncio.start_server(server_stream, host=HOST, port=PORT)
    # Run the client inside the server's context manager so the listening
    # server is closed even if the client exchange raises.
    async with server:
        await client()
        await server.serve_forever()


asyncio.run(main())