diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index 569157b277..846bece39b 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -1,3 +1,4 @@ +import asyncio import logging import sys import tempfile @@ -177,15 +178,49 @@ class ASGIHandler(base.BaseHandler): body_file.close() await self.send_response(error_response, send) return - # Get the response, using the async mode of BaseHandler. + # Try to catch a disconnect while getting response. + tasks = [ + asyncio.create_task(self.run_get_response(request)), + asyncio.create_task(self.listen_for_disconnect(receive)), + ] + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending = done.pop(), pending.pop() + # Allow views to handle cancellation. + pending.cancel() + try: + await pending + except asyncio.CancelledError: + # Task re-raised the CancelledError as expected. + pass + try: + response = done.result() + except RequestAborted: + body_file.close() + return + except AssertionError: + body_file.close() + raise + # Send the response. + await self.send_response(response, send) + + async def listen_for_disconnect(self, receive): + """Listen for disconnect from the client.""" + message = await receive() + if message["type"] == "http.disconnect": + raise RequestAborted() + # This should never happen. + assert False, "Invalid ASGI message after request body: %s" % message["type"] + + async def run_get_response(self, request): + """Get async response.""" + # Use the async mode of BaseHandler. response = await self.get_response_async(request) response._handler_class = self.__class__ # Increase chunk size on file responses (ASGI servers handles low-level # chunking). if isinstance(response, FileResponse): response.block_size = self.chunk_size - # Send the response. - await self.send_response(response, send) + return response async def read_body(self, receive): """Reads an HTTP body from an ASGI connection.""" diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 6e911f471e..cc4fb69ee3 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -192,6 +192,13 @@ Minor features * ... +Asynchronous views +~~~~~~~~~~~~~~~~~~ + +* Under ASGI, ``http.disconnect`` events are now handled. This allows views to + perform any necessary cleanup if a client disconnects before the response is + generated. See :ref:`async-handling-disconnect` for more details. + Cache ~~~~~ diff --git a/docs/topics/async.txt b/docs/topics/async.txt index 95d3435e07..5a2324af5e 100644 --- a/docs/topics/async.txt +++ b/docs/topics/async.txt @@ -136,6 +136,26 @@ a purely synchronous codebase under ASGI because the request-handling code is still all running asynchronously. In general you will only want to enable ASGI mode if you have asynchronous code in your project. +.. _async-handling-disconnect: + +Handling disconnects +-------------------- + +.. versionadded:: 5.0 + +For long-lived requests, a client may disconnect before the view returns a +response. In this case, an ``asyncio.CancelledError`` will be raised in the +view. You can catch this error and handle it if you need to perform any +cleanup:: + + async def my_view(request): + try: + # Do some work + ... + except asyncio.CancelledError: + # Handle disconnect + raise + .. _async-safety: Async safety diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index fc22a992a7..0222b5356e 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -7,8 +7,10 @@ from asgiref.testing import ApplicationCommunicator from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler from django.core.asgi import get_asgi_application +from django.core.handlers.asgi import ASGIHandler, ASGIRequest from django.core.signals import request_finished, request_started from django.db import close_old_connections +from django.http import HttpResponse from django.test import ( AsyncRequestFactory, SimpleTestCase, @@ -16,6 +18,7 @@ from django.test import ( modify_settings, override_settings, ) +from django.urls import path from django.utils.http import http_date from .urls import sync_waiter, test_filename @@ -234,6 +237,34 @@ class ASGITest(SimpleTestCase): with self.assertRaises(asyncio.TimeoutError): await communicator.receive_output() + async def test_disconnect_with_body(self): + application = get_asgi_application() + scope = self.async_request_factory._base_scope(path="/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request", "body": b"some body"}) + await communicator.send_input({"type": "http.disconnect"}) + with self.assertRaises(asyncio.TimeoutError): + await communicator.receive_output() + + async def test_assert_in_listen_for_disconnect(self): + application = get_asgi_application() + scope = self.async_request_factory._base_scope(path="/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request"}) + await communicator.send_input({"type": "http.not_a_real_message"}) + msg = "Invalid ASGI message after request body: http.not_a_real_message" + with self.assertRaisesMessage(AssertionError, msg): + await communicator.receive_output() + + async def test_delayed_disconnect_with_body(self): + application = get_asgi_application() + scope = self.async_request_factory._base_scope(path="/delayed_hello/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request", "body": b"some body"}) + await communicator.send_input({"type": "http.disconnect"}) + with self.assertRaises(asyncio.TimeoutError): + await communicator.receive_output() + async def test_wrong_connection_type(self): application = get_asgi_application() scope = self.async_request_factory._base_scope(path="/", type="other") @@ -318,3 +349,56 @@ class ASGITest(SimpleTestCase): self.assertEqual(len(sync_waiter.active_threads), 2) sync_waiter.active_threads.clear() + + async def test_asyncio_cancel_error(self): + # Flag to check if the view was cancelled. + view_did_cancel = False + + # A view that will listen for the cancelled error. + async def view(request): + nonlocal view_did_cancel + try: + await asyncio.sleep(0.2) + return HttpResponse("Hello World!") + except asyncio.CancelledError: + # Set the flag. + view_did_cancel = True + raise + + # Request class to use the view. + class TestASGIRequest(ASGIRequest): + urlconf = (path("cancel/", view),) + + # Handler to use request class. + class TestASGIHandler(ASGIHandler): + request_class = TestASGIRequest + + # Request cycle should complete since no disconnect was sent. + application = TestASGIHandler() + scope = self.async_request_factory._base_scope(path="/cancel/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request"}) + response_start = await communicator.receive_output() + self.assertEqual(response_start["type"], "http.response.start") + self.assertEqual(response_start["status"], 200) + response_body = await communicator.receive_output() + self.assertEqual(response_body["type"], "http.response.body") + self.assertEqual(response_body["body"], b"Hello World!") + # Give response.close() time to finish. + await communicator.wait() + self.assertIs(view_did_cancel, False) + + # Request cycle with a disconnect before the view can respond. + application = TestASGIHandler() + scope = self.async_request_factory._base_scope(path="/cancel/") + communicator = ApplicationCommunicator(application, scope) + await communicator.send_input({"type": "http.request"}) + # Let the view actually start. + await asyncio.sleep(0.1) + # Disconnect the client. + await communicator.send_input({"type": "http.disconnect"}) + # The handler should not send a response. + with self.assertRaises(asyncio.TimeoutError): + await communicator.receive_output() + await communicator.wait() + self.assertIs(view_did_cancel, True) diff --git a/tests/asgi/urls.py b/tests/asgi/urls.py index 34595c1b6c..0f74fc9b97 100644 --- a/tests/asgi/urls.py +++ b/tests/asgi/urls.py @@ -1,4 +1,5 @@ import threading +import time from django.http import FileResponse, HttpResponse from django.urls import path @@ -10,6 +11,12 @@ def hello(request): return HttpResponse("Hello %s!" % name) +def hello_with_delay(request): + name = request.GET.get("name") or "World" + time.sleep(1) + return HttpResponse(f"Hello {name}!") + + def hello_meta(request): return HttpResponse( "From %s" % request.META.get("HTTP_REFERER") or "", @@ -46,4 +53,5 @@ urlpatterns = [ path("meta/", hello_meta), path("post/", post_echo), path("wait/", sync_waiter), + path("delayed_hello/", hello_with_delay), ]