How do I enable TLS on an already connected Python asyncio stream?

I have a Python asyncio server written using the high-level Streams API. I want to enable TLS on an already established connection, as with STARTTLS in the SMTP and IMAP protocols. The asyncio event loop has a start_tls() method (added in Python 3.7), but it takes a transport and a protocol rather than a stream. The Streams API does let you get the transport via StreamWriter.transport, but start_tls() returns a new transport, and I don't see a way to make the existing StreamReader/StreamWriter use it.

Is it possible to use start_tls() with the streams API?

Catanddog answered 11/7, 2020 at 15:55
I've coded up the solutions posted here. Obviously it's a rubbish solution as it depends on private variables. Maybe something like it could be added to the standard library. – Bioenergetics

As of Python 3.11, this is as simple as calling StreamWriter.start_tls(): https://docs.python.org/3/library/asyncio-stream.html#asyncio.StreamWriter.start_tls

E.g. for a client:

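import asyncio
import ssl

HOST, PORT = 'your_server_uri_here', 8000  # placeholders

async def main():
    reader, writer = await asyncio.open_connection(HOST, PORT)
    context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
    # Adjust the SSL context for your needs (e.g. load_verify_locations() for a self-signed cert)
    # server_hostname enables SNI and hostname verification of the server certificate
    await writer.start_tls(context, server_hostname=HOST)

asyncio.run(main())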
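A slightly fuller client sketch, assuming a simplified SMTP-like exchange: the hypothetical helper `starttls_client` sends a bare `STARTTLS` line (no greeting or EHLO, unlike real SMTP) and only upgrades if the reply starts with `220`, the code RFC 3207 uses for "Ready to start TLS". It needs Python 3.11+ for `StreamWriter.start_tls()`.

```python
import asyncio
import ssl
from typing import Tuple


async def starttls_client(
    host: str, port: int, context: ssl.SSLContext
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
    """Connect in plaintext, request STARTTLS, then upgrade (Python 3.11+).

    The command/reply format is a simplified, hypothetical SMTP-like protocol.
    """
    reader, writer = await asyncio.open_connection(host, port)

    # Plaintext phase: ask the server to switch to TLS.
    writer.write(b"STARTTLS\r\n")
    await writer.drain()
    reply = await reader.readline()

    if not reply.startswith(b"220"):
        # Server refused (e.g. "454 TLS not available"): close and report it.
        writer.close()
        await writer.wait_closed()
        raise ConnectionError(f"STARTTLS refused: {reply!r}")

    # Upgrade in place; server_hostname is used for SNI and for checking
    # the name in the server certificate.
    await writer.start_tls(context, server_hostname=host)
    return reader, writer
```

Passing `server_hostname=host` matters with a default context: without it, the certificate's name cannot be checked against the host you connected to.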

For a server, add the upgrade call to your connection handler (which must be a coroutine, since it awaits):

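import asyncio
import ssl
from asyncio import StreamReader, StreamWriter

# Modify to match your needs
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain("path/to/cert", "path/to/key")

async def handler(reader: StreamReader, writer: StreamWriter):
    # ... plaintext exchange (e.g. read STARTTLS and reply) ...
    await writer.start_tls(context)
    # Do stuff: reader and writer now go over TLS

async def main():
    server = await asyncio.start_server(handler, "server hostname", 8000)
    async with server:
        await server.serve_forever()

asyncio.run(main())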
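The server side of the same simplified exchange might look like this (hypothetical `handle_starttls`; the `502`/`220` replies mimic SMTP). The explicit `drain()` makes it obvious that the plaintext reply is flushed before the handshake begins; as far as I can tell, `StreamWriter.start_tls()` also drains first and infers `server_side` from the fact that the writer came from `asyncio.start_server()`.

```python
import asyncio
import ssl


async def handle_starttls(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    context: ssl.SSLContext,
) -> bool:
    """Plaintext phase of a simplified STARTTLS exchange, then upgrade.

    Returns False (without upgrading) if the first command is not STARTTLS.
    """
    command = await reader.readline()
    if command.strip().upper() != b"STARTTLS":
        writer.write(b"502 Command not implemented\r\n")
        await writer.drain()
        return False

    writer.write(b"220 Ready to start TLS\r\n")
    # Flush the plaintext reply before the TLS handshake starts.
    await writer.drain()
    # Python 3.11+: upgrade this server-side stream in place.
    await writer.start_tls(context)
    return True
```

It can be registered with something like `asyncio.start_server(lambda r, w: handle_starttls(r, w, context), host, port)`, since asyncio schedules the coroutine the callback returns.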
Labiodental answered 19/8, 2023 at 2:39

Looking at the code for the Streams API, you'll notice that StreamReader and StreamWriter both store their transport in an internal _transport attribute. It turns out that if you call loop.start_tls() and then store the new transport in those attributes, it works just fine. All the usual caveats about using an internal API apply, of course. Here's what this looks like for a server. On a client, I think you can drop the load_cert_chain() call and the server_side argument (and pass server_hostname instead).

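import asyncio
import ssl

# Inside an async connection handler that has `reader` and `writer`:
await writer.drain()  # flush any pending plaintext first
transport = writer.transport
protocol = transport.get_protocol()
loop = asyncio.get_running_loop()
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain("/path/to/certchain", "/path/to/key")
new_transport = await loop.start_tls(
    transport, protocol, ssl_context, server_side=True)
# Internal attributes: this may break in any Python release
writer._transport = new_transport
reader._transport = new_transport

# prevent warning "returning true from eof_received() has no effect when using ssl"
protocol._over_ssl = True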
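If the same code has to run on both older and newer Pythons, one option is a small wrapper (hypothetical name `upgrade_stream`) that prefers the public `StreamWriter.start_tls()` when it exists (3.11+) and otherwise falls back to the private-attribute swap shown above. This is a sketch; the fallback branch carries all the same caveats as the original snippet.

```python
import asyncio
import ssl
from typing import Optional


async def upgrade_stream(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    context: ssl.SSLContext,
    *,
    server_side: bool = False,
    server_hostname: Optional[str] = None,
) -> None:
    """Upgrade an existing stream pair to TLS (hypothetical helper)."""
    # The hostname only makes sense on the client side.
    hostname = None if server_side else server_hostname

    if hasattr(writer, "start_tls"):
        # Python 3.11+: public API; it works out server_side by itself.
        await writer.start_tls(context, server_hostname=hostname)
        return

    # Python 3.7 to 3.10: loop.start_tls() plus the private _transport
    # attributes. This may break in any release.
    await writer.drain()
    transport = writer.transport
    protocol = transport.get_protocol()
    loop = asyncio.get_running_loop()
    new_transport = await loop.start_tls(
        transport,
        protocol,
        context,
        server_side=server_side,
        server_hostname=hostname,
    )
    writer._transport = new_transport
    reader._transport = new_transport
```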
Catanddog answered 11/7, 2020 at 15:55
This approach (relying on an internal attribute) can stop working at any time, even in a bugfix release, and should never be used for production code. – Cimex
@Cimex Well then, please show us HOW to properly implement TLS upgrade without using internal attributes. – Airframe
@Airframe That's a non-sequitur and you know it. Not only is an internal attribute being used here, it is being assigned a different transport object without taking care that the old one is properly detached. People often google things on SO and copy-paste snippets into their code; for this snippet it makes ample sense to warn readers not to do so. – Cimex
@Cimex It's not a "non-sequitur". Yes, I agree with the sentiment that one must be careful using internal attributes, but "should never be used for production code" is too strong. The fact is, before 3.7, asyncio did not provide a way to implement TLS upgrade without using internal attributes, and someone who needs to do that has no options. Heck, even in 3.7+ the start_tls() method is not documented properly. – Airframe
@Airframe It's a non-sequitur because pointing out a serious flaw in an answer doesn't imply an obligation to come up with the correct solution. Your final argument that the wording of my comment was perhaps too strong has merit and is not a non-sequitur (even if I disagree), but that's not how you initiated the exchange. – Cimex
@Cimex My comment was not a "non-sequitur" because my comment is directly related to the Original Question. Without the context of the Original Question, your comment is totally correct, and I agree, my response would be irrelevant. However, because of the context of the Original Question, plus the Needs of the Person Posting the Question, my response becomes totally relevant. You made a strong, all-encompassing statement, "should NEVER be used", while not providing a solution to the Questioner's Needs. – Airframe
@Airframe I already explained why your comment was a non-sequitur, so I won't repeat it. You don't add anything new to the discussion; you just repeat what you previously said about the wording being too strong, which I accept is a valid criticism even if I don't agree with it. – Cimex
@Cimex I added a new solution, what do you think about it? – Lotuseater
@Lotuseater Unless I am misreading it, your solution doesn't upgrade an already existing connection, which is what the question asked for. – Cimex
@Cimex It does upgrade, but the actual StreamReader and StreamWriter need to be replaced. I use that solution to implement HTTP proxy support: I need to establish a connection (without TLS) to the proxy server, then upgrade it to TLS with the target host. The underlying socket is the same. – Lotuseater
@hldevv I get that, but I also understand the OP to hold an existing stream that needs to be upgraded. Your code both creates and upgrades the stream. – Cimex

I needed to implement proxy support for asyncio streams on Python 3.8 and came up with the following solution:

import socket
import weakref
import asyncio
import typing as t
from ssl import create_default_context, Purpose, SSLContext


class TLSStreamReaderProtocol(asyncio.StreamReaderProtocol):

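    def upgrade_reader(self, reader: asyncio.StreamReader):
        # Fail the old reader so nobody keeps reading plaintext from it.
        if self._stream_reader is not None:
            self._stream_reader.set_exception(Exception('upgraded connection to TLS, this reader is obsolete now.'))
        # Point the protocol at the new reader (mirrors StreamReaderProtocol.__init__).
        self._stream_reader_wr = weakref.ref(reader)
        self._source_traceback = reader._source_traceback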


async def open_tls_stream(host: str, port: int, ssl: t.Union[SSLContext, bool]=False):
    # This does the same as asyncio.open_connection(), but the TLS upgrade
    # is done manually after the connection has been established.
    loop = asyncio.get_running_loop()
    reader = asyncio.StreamReader(limit=2**64, loop=loop)
    protocol = TLSStreamReaderProtocol(reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: protocol, host, port, family=socket.AF_INET
    )
    writer = asyncio.StreamWriter(transport, protocol, reader, loop)
    # here you can use reader and writer for whatever you want, for example
    # start a proxy connection and start TLS to target host later...
    # now perform TLS upgrade
    if ssl:
        transport = await loop.start_tls(
            transport,
            protocol,
            sslcontext=create_default_context(Purpose.SERVER_AUTH) if isinstance(ssl, bool) else ssl,
            server_side=False,
            server_hostname=host
        )
        reader = asyncio.StreamReader(limit=2**64, loop=loop)
        protocol.upgrade_reader(reader) # update reader
        protocol.connection_made(transport) # update transport
        writer = asyncio.StreamWriter(transport, protocol, reader, loop) # update writer
    return reader, writer     
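The proxy use case mentioned above needs a plaintext step before the upgrade: asking the proxy for a tunnel with an HTTP `CONNECT` request. A minimal sketch, assuming a proxy that needs no authentication (hypothetical helper `http_connect`); once it returns, the stream is a raw tunnel to the target and can be upgraded with `server_hostname` set to the target host.

```python
import asyncio


async def http_connect(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    host: str,
    port: int,
) -> None:
    """Ask an HTTP proxy to open a tunnel to host:port (no proxy auth)."""
    target = f"{host}:{port}"
    request = f"CONNECT {target} HTTP/1.1\r\nHost: {target}\r\n\r\n"
    writer.write(request.encode("ascii"))
    await writer.drain()

    # Status line looks like b"HTTP/1.1 200 Connection established\r\n".
    status_line = await reader.readline()
    parts = status_line.split(maxsplit=2)
    if len(parts) < 2 or parts[1] != b"200":
        raise ConnectionError(f"proxy refused CONNECT: {status_line!r}")

    # Skip the response headers up to the blank line that ends them;
    # anything after that belongs to the tunnelled connection.
    while True:
        line = await reader.readline()
        if line in (b"\r\n", b"\n", b""):
            break
```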
Lotuseater answered 7/8, 2021 at 21:25

I'm using the following:

import asyncio
import ssl
from typing import Optional


async def tls_handshake(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    ssl_context: Optional[ssl.SSLContext] = None,
    server_side: bool = False,
):
    """
    Manually perform a TLS handshake over a stream.

    Args:
        reader: The reader of the client connection.
        writer: The writer of the client connection.
        ssl_context: The SSL context to use for the TLS/SSL handshake. Defaults to None.
        server_side: Whether the connection is server-side or not.

    Note:
        If the SSL context is not passed and the connection is not server-side,
        `ssl.create_default_context()` will be used.

        ``loop.start_tls()`` requires Python 3.7+. On Python 3.7 to 3.9 you can use
        ``ssl.PROTOCOL_TLS`` for the SSL context. On Python 3.10+ ``ssl.PROTOCOL_TLS``
        is deprecated, so use ``ssl.PROTOCOL_TLS_CLIENT`` or ``ssl.PROTOCOL_TLS_SERVER``
        depending on the role of the reader/writer.
    """

    if not server_side and not ssl_context:
        ssl_context = ssl.create_default_context()

    # Flush any pending plaintext before the handshake starts.
    await writer.drain()
    transport = writer.transport
    protocol = transport.get_protocol()

    loop = asyncio.get_running_loop()
    new_transport = await loop.start_tls(
        transport=transport,
        protocol=protocol,
        sslcontext=ssl_context,
        server_side=server_side,
    )

    reader._transport = new_transport
    writer._transport = new_transport

I've added the above function to toolbox as well.
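To check that a swapped-in transport actually took effect, you can ask the writer for the `ssl_object` extra info: `StreamWriter.get_extra_info()` delegates to whatever transport the writer currently holds, so after the `_transport` swap it should report the TLS transport's details. A small sketch (hypothetical helper name `is_tls`):

```python
import asyncio


def is_tls(writer: asyncio.StreamWriter) -> bool:
    """Return True if the writer's current transport is TLS-wrapped."""
    # Plain TCP transports have no "ssl_object" entry, so this returns
    # None for them; TLS transports expose the underlying ssl.SSLObject.
    return writer.get_extra_info("ssl_object") is not None
```

Calling it right after `tls_handshake()` is a cheap sanity check that the reader/writer pair is no longer talking to the plaintext transport.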

Here is an example of a client that manually performs the TLS handshake after connecting:

import asyncio
import ssl

from toolbox.asyncio.stream import tls_handshake

async def client():
    reader, writer = await asyncio.open_connection("httpbin.org", 443, ssl=False)
    await tls_handshake(reader=reader, writer=writer)
    # Communication is now encrypted.

asyncio.run(client())

And here is a more advanced example of a server and client:

import asyncio
import ssl

from toolbox.asyncio.stream import tls_handshake

HOST = "127.0.0.1"
PORT = 8888
CERT = "server.crt"
KEY = "server.key"

async def server_stream(reader, writer):
    # Perform TLS handshake before sending/receiving data.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(CERT, KEY)
    await tls_handshake(
        reader=reader,
        writer=writer,
        ssl_context=context,
        server_side=True,
    )

    # Receive from client.
    data = await reader.read(1024)
    print("Received from client:", data)

    # Send from server.
    writer.write(b"Server here.")
    await writer.drain()

    # Closes the connection server-side.
    writer.close()
    await writer.wait_closed()

async def client():
    # Open connection to server.
    reader, writer = await asyncio.open_connection(host=HOST, port=PORT)

    # Perform TLS handshake before sending/receiving data.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.load_verify_locations(CERT)
    await tls_handshake(
        reader=reader,
        writer=writer,
        ssl_context=context,
    )

    # Send from client.
    writer.write(b"Client here.")
    await writer.drain()

    # Receive from server.
    data = await reader.read(1024)
    print("Received from server:", data)

    # Closes the connection client-side.
    writer.close()
    await writer.wait_closed()

async def main():
    server = await asyncio.start_server(server_stream, host=HOST, port=PORT)
    await client()
    async with server:
        await server.serve_forever()

asyncio.run(main())
Wittgenstein answered 28/11, 2021 at 17:09
