# Source code for scrapy.contracts

from __future__ import annotations

import re
import sys
from collections.abc import AsyncGenerator, Iterable
from functools import wraps
from inspect import getmembers
from types import CoroutineType
from typing import TYPE_CHECKING, Any, cast
from unittest import TestCase, TestResult

from scrapy.http import Request, Response
from scrapy.utils.python import get_spec
from scrapy.utils.spider import iterate_spider_output

if TYPE_CHECKING:
    from collections.abc import Callable

    from twisted.python.failure import Failure

    from scrapy import Spider


class Contract:
    """Abstract class for contracts.

    A contract decorates the callback of a request so that the optional
    ``pre_process(response)`` and ``post_process(output)`` hooks defined by
    a subclass run around that callback and report their outcome to a
    ``unittest.TestResult``-style object.
    """

    # Request class to build the request with; ``None`` means the contract
    # does not impose one.
    request_cls: type[Request] | None = None
    # Identifier of the contract; it is used in the test-case descriptions
    # created in ``__init__``.
    name: str

    def __init__(self, method: Callable, *args: Any):
        """Create the pre/post test cases for *method* and store *args*.

        :param method: the callback the contract is attached to.
        :param args: positional arguments given to the contract.
        """
        self.testcase_pre = _create_testcase(method, f"@{self.name} pre-hook")
        self.testcase_post = _create_testcase(method, f"@{self.name} post-hook")
        self.args: tuple[Any, ...] = args

    def add_pre_hook(self, request: Request, results: TestResult) -> Request:
        """Wrap ``request.callback`` so ``pre_process`` runs before it.

        Nothing is changed when the contract defines no ``pre_process``.

        :param request: request whose callback is wrapped in place.
        :param results: collector that receives the outcome of the hook.
        :returns: the same *request* object.
        :raises TypeError: (from the wrapper) if the callback result is an
            async generator or a coroutine.
        """
        if hasattr(self, "pre_process"):
            cb = request.callback
            assert cb is not None

            @wraps(cb)
            def wrapper(response: Response, **cb_kwargs: Any) -> list[Any]:
                try:
                    results.startTest(self.testcase_pre)
                    self.pre_process(response)
                except AssertionError:
                    results.addFailure(self.testcase_pre, sys.exc_info())
                except Exception:
                    results.addError(self.testcase_pre, sys.exc_info())
                else:
                    results.addSuccess(self.testcase_pre)
                finally:
                    # Close the test on every path, after the outcome has
                    # been recorded, so start/stop calls stay balanced.
                    results.stopTest(self.testcase_pre)
                # The callback runs regardless of the hook outcome.
                cb_result = cb(response, **cb_kwargs)
                if isinstance(cb_result, (AsyncGenerator, CoroutineType)):
                    raise TypeError("Contracts don't support async callbacks")
                return list(cast(Iterable[Any], iterate_spider_output(cb_result)))

            request.callback = wrapper

        return request

    def add_post_hook(self, request: Request, results: TestResult) -> Request:
        """Wrap ``request.callback`` so ``post_process`` runs on its output.

        Nothing is changed when the contract defines no ``post_process``.

        :param request: request whose callback is wrapped in place.
        :param results: collector that receives the outcome of the hook.
        :returns: the same *request* object.
        :raises TypeError: (from the wrapper) if the callback result is an
            async generator or a coroutine.
        """
        if hasattr(self, "post_process"):
            cb = request.callback
            assert cb is not None

            @wraps(cb)
            def wrapper(response: Response, **cb_kwargs: Any) -> list[Any]:
                cb_result = cb(response, **cb_kwargs)
                if isinstance(cb_result, (AsyncGenerator, CoroutineType)):
                    raise TypeError("Contracts don't support async callbacks")
                # Materialise the output so the hook and the caller both
                # see the same list.
                output = list(cast(Iterable[Any], iterate_spider_output(cb_result)))
                try:
                    results.startTest(self.testcase_post)
                    self.post_process(output)
                except AssertionError:
                    results.addFailure(self.testcase_post, sys.exc_info())
                except Exception:
                    results.addError(self.testcase_post, sys.exc_info())
                else:
                    results.addSuccess(self.testcase_post)
                finally:
                    # Close the test on every path, after the outcome has
                    # been recorded, so start/stop calls stay balanced.
                    results.stopTest(self.testcase_post)
                return output

            request.callback = wrapper

        return request

    def adjust_request_args(self, args: dict[str, Any]) -> dict[str, Any]:
        """Return the keyword arguments to build the request with.

        The base implementation returns *args* unchanged; subclasses may
        override it to add or modify entries.
        """
        return args
class ContractsManager:
    """Build contract-instrumented requests from the docstrings of spider methods."""

    # Class-level defaults; each instance works on its own copy (see __init__).
    contracts: dict[str, type[Contract]] = {}

    def __init__(self, contracts: Iterable[type[Contract]]):
        """Register *contracts* under their ``name`` attribute.

        The registry is built per instance, seeded from the class-level
        ``contracts`` mapping, so registering contracts in one manager does
        not leak into other managers or into the class itself.
        """
        registry = dict(type(self).contracts)
        for contract in contracts:
            registry[contract.name] = contract
        self.contracts = registry

    def tested_methods_from_spidercls(self, spidercls: type[Spider]) -> list[str]:
        """Return the names of callables in *spidercls* whose docstring has
        at least one line starting (after optional whitespace) with ``@``.
        """
        is_method = re.compile(r"^\s*@", re.MULTILINE).search
        return [
            key
            for key, value in getmembers(spidercls)
            if callable(value) and value.__doc__ and is_method(value.__doc__)
        ]

    def extract_contracts(self, method: Callable) -> list[Contract]:
        """Instantiate one contract per ``@name args...`` docstring line.

        :raises KeyError: if a line names a contract that is not registered.
        """
        contracts: list[Contract] = []
        assert method.__doc__ is not None
        for line in method.__doc__.split("\n"):
            line = line.strip()
            if not line.startswith("@"):
                continue
            m = re.match(r"@(\w+)\s*(.*)", line)
            if m is None:
                continue
            name, args = m.groups()
            # NOTE: splitting an empty argument string yields [""], so a
            # bare "@name" line passes a single empty-string argument.
            args = re.split(r"\s+", args)
            contracts.append(self.contracts[name](method, *args))
        return contracts

    def from_spider(self, spider: Spider, results: TestResult) -> list[Request | None]:
        """Build a request for every contract-annotated method of *spider*.

        An exception raised while building a request is recorded on
        *results* as a "contract" error and that method is skipped.

        :returns: the values returned by :meth:`from_method`, which may
            include ``None`` entries.
        """
        requests: list[Request | None] = []
        for method in self.tested_methods_from_spidercls(type(spider)):
            bound_method = getattr(spider, method)
            try:
                requests.append(self.from_method(bound_method, results))
            except Exception:
                case = _create_testcase(bound_method, "contract")
                results.addError(case, sys.exc_info())
        return requests

    def from_method(self, method: Callable, results: TestResult) -> Request | None:
        """Build the request described by the contracts of *method*.

        :returns: the instrumented request, or ``None`` when the method has
            no contracts or the contracts do not supply every positional
            argument of the request constructor.
        """
        contracts = self.extract_contracts(method)
        if not contracts:
            return None

        # The last contract that declares a request class wins.
        request_cls = Request
        for contract in contracts:
            if contract.request_cls is not None:
                request_cls = contract.request_cls

        # calculate request args
        args, kwargs = get_spec(request_cls.__init__)

        # Don't filter requests to allow
        # testing different callbacks on the same URL.
        kwargs["dont_filter"] = True
        kwargs["callback"] = method

        for contract in contracts:
            kwargs = contract.adjust_request_args(kwargs)

        args.remove("self")

        # check if all positional arguments are defined in kwargs
        if not set(args).issubset(set(kwargs)):
            return None

        request = request_cls(**kwargs)

        # execute pre and post hooks in order
        for contract in reversed(contracts):
            request = contract.add_pre_hook(request, results)
        for contract in contracts:
            request = contract.add_post_hook(request, results)

        self._clean_req(request, method, results)
        return request

    def _clean_req(
        self, request: Request, method: Callable, results: TestResult
    ) -> None:
        """stop the request from returning objects and records any errors"""
        cb = request.callback
        assert cb is not None

        @wraps(cb)
        def cb_wrapper(response: Response, **cb_kwargs: Any) -> None:
            try:
                output = cb(response, **cb_kwargs)
                # Consume the output so any error in it surfaces here; the
                # result itself is deliberately discarded.
                output = list(cast(Iterable[Any], iterate_spider_output(output)))
            except Exception:
                case = _create_testcase(method, "callback")
                results.addError(case, sys.exc_info())

        def eb_wrapper(failure: Failure) -> None:
            case = _create_testcase(method, "errback")
            exc_info = failure.type, failure.value, failure.getTracebackObject()
            results.addError(case, exc_info)

        request.callback = cb_wrapper
        request.errback = eb_wrapper


def _create_testcase(method: Callable, desc: str) -> TestCase:
    """Return a ``TestCase`` whose string form identifies *method* and *desc*.

    *method* must be a bound method of an object that has a ``name``
    attribute; the case renders as ``[<name>] <method name> (<desc>)``.
    """
    spider = method.__self__.name  # type: ignore[attr-defined]

    class ContractTestCase(TestCase):
        def __str__(_self) -> str:
            return f"[{spider}] {method.__name__} ({desc})"

    # TestCase requires the named test method to exist on the class.
    name = f"{spider}_{method.__name__}"
    setattr(ContractTestCase, name, lambda x: x)
    return ContractTestCase(name)