Source code for ats_linter.linter

"""Copyright (c) 2023 Aydin Abdi.

Module to lint test files.

This module provides a class to lint test files.
"""

from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass, field
from threading import Lock
from typing import Any

from loguru import logger

from ats_linter.data_classes import Section, TestCase
from ats_linter.description import (
    SECTION_APPROVALS,
    SECTION_DATA_DRIVEN_TEST,
    SECTION_OBJECTIVE,
    SECTION_PRECONDITIONS,
    SECTION_TEST_STEPS,
    TestDescription,
)

# Module logging toggle.
# NOTE(review): loguru's ``logger.disable(name)`` silences messages coming from
# the module/package called ``name``. The argument below is the string literal
# "__name__", not this module's ``__name__`` variable, so it does not silence
# ``ats_linter.linter``. It is kept as-is because lint failures in this module
# are reported via ``logger.error``. Confirm the intent before switching to
# ``logger.disable(__name__)``.
logger.disable("__name__")

# Sections every ATS test docstring is required to contain.
MANDATORY_SECTIONS = [SECTION_OBJECTIVE, SECTION_APPROVALS, SECTION_TEST_STEPS]
# Sections that are recognised but not required.
OPTIONAL_SECTIONS = [SECTION_PRECONDITIONS, SECTION_DATA_DRIVEN_TEST]
# All recognised section names: mandatory ones first, then optional ones.
SECTION_NAMES = [*MANDATORY_SECTIONS, *OPTIONAL_SECTIONS]

# Error message templates, filled in with ``str.format``.
MISSING_SECTION_ERROR_MESSAGE = "Missing '{section_name}' section"
# The two literals are concatenated at compile time, so the space separating
# them must be written explicitly. Without it the message reads
# "...='3'and 'Verify steps'...".
MISMATCH_APPROVALS_VERIFY_ERROR_MESSAGE = (
    "Mismatch between amount of 'Approvals'='{approvals}' "
    "and 'Verify steps'='{verifies}' sections"
)


@dataclass
class ATSTestCase:
    """Pair a :class:`TestCase` with the ATS description parsed from it.

    Parameters
    ----------
    test_case:
        The test case whose docstring is parsed.
    test_description:
        Not an init argument. It is set in ``__post_init__`` from
        ``TestDescriptionFactory.from_docstring(test_case.docstring)``.
    """

    test_case: TestCase
    test_description: TestDescription = field(init=False, default=None)

    def __post_init__(self):
        """Parse the wrapped test case's docstring into ``test_description``."""
        from ats_linter.description import TestDescriptionFactory

        raw_docstring = self.test_case.docstring
        self.test_description = TestDescriptionFactory.from_docstring(raw_docstring)
        logger.debug(f"ATS test description: {self.test_description}")

    def __dict__(self) -> dict[str, Any]:
        """Return the parsed test description converted with ``asdict``.

        NOTE(review): defining ``__dict__`` as a method shadows the instance
        attribute dictionary. As a result, ``obj.__dict__`` and ``vars(obj)``
        yield this bound method rather than a mapping. It is kept for callers
        that invoke ``obj.__dict__()``; a ``to_dict`` name would avoid the
        clash.

        Returns:
            The ATS test description as a dict.
        """
        return asdict(self.test_description)

    def __len__(self) -> int:
        """Return how many verify steps the parsed description holds.

        Because this is ``__len__``, an instance with zero verify steps is
        falsy.

        Returns:
            Number of verify steps.
        """
        steps = self.test_description.verify_steps
        step_count = len(steps)
        logger.debug(f"Verify steps: {steps}, number of verify steps: {step_count}")
        return step_count
@dataclass
class ATSTestCasesFactory:
    """Build :class:`ATSTestCase` objects for a list of test cases.

    Parameters
    ----------
    test_cases:
        The :class:`TestCase` objects to wrap.
    ats_test_cases:
        Not an init argument. It is filled by ``__post_init__`` with one
        :class:`ATSTestCase` per entry of ``test_cases``, in the same order.
    """

    test_cases: list[TestCase]
    ats_test_cases: list[ATSTestCase] = field(init=False, default_factory=list)

    def __post_init__(self):
        """Create the :class:`ATSTestCase` objects right after construction."""
        self._create_ats_test_cases()

    def _create_ats_test_cases(self) -> None:
        """Create :class:`ATSTestCase` objects from ``self.test_cases``.

        Construction runs on a thread pool. ``Executor.map`` yields results
        in submission order, so ``ats_test_cases[i]`` corresponds to
        ``test_cases[i]``. The previous approach, ``as_completed`` over a set
        of futures, produced an arbitrary order that varied between runs. An
        exception raised while building any test case propagates to the
        caller.
        """
        with ThreadPoolExecutor() as executor:
            self.ats_test_cases.extend(
                executor.map(self._create_ats_test_case, self.test_cases),
            )

    def _create_ats_test_case(self, test_case: TestCase) -> ATSTestCase:
        """Create a :class:`ATSTestCase` object from a test case.

        Args:
            test_case: The test case to wrap.

        Returns:
            The :class:`ATSTestCase` object created from the test case.
        """
        return ATSTestCase(test_case)

    def __len__(self) -> int:
        """Return the number of :class:`ATSTestCase` objects created.

        Returns:
            Number of :class:`ATSTestCase` objects.
        """
        return len(self.ats_test_cases)
@dataclass
class LintResult:
    """Record describing the outcome of linting one test case.

    ``LintTestCase`` declares a ``lint_result`` field of this type but
    currently leaves it as ``None``.
    """

    # Module name associated with the linted test case.
    module_name: str
    # Class name associated with the linted test case.
    class_name: str
    # Name of the linted test.
    test_name: str
    # Sections checked during linting; a set ``error_message`` marks a failure.
    sections: list[Section]
    # Overall pass (True) / fail (False).
    result: bool
@dataclass
class LintTestCase:
    """Lint the docstring of a single ATS test case.

    Parameters
    ----------
    ats_test_case:
        The :class:`ATSTestCase` to lint.
    test_case:
        Not an init argument; copied from ``ats_test_case.test_case``.
    test_description:
        Not an init argument; copied from ``ats_test_case.test_description``.
    sections:
        Not an init argument. It holds one :class:`Section` per name in
        ``SECTION_NAMES``, plus any extra sections ``lint`` appends to report
        docstring-level problems. A section whose ``error_message`` is set
        counts as a failure.
    lint_result:
        Not an init argument; currently always ``None``.

    Example:
        (Doctest temporarily disabled due to API complexity)

        # >>> from ats_linter.data_classes import TestCase
        # >>> from ats_linter.linter import ATSTestCase, LintTestCase
        # >>> test_case = TestCase(
        # ...     name="Test case name",
        # ...     docstring="Test case description",
        # ...     code="def test_something(): pass",
        # ... )
        # >>> ats_test_case = ATSTestCase(test_case)
        # >>> test_case_linter = LintTestCase(ats_test_case)
        # >>> test_case_linter.lint()
        # >>> test_case_linter.sections
    """

    ats_test_case: ATSTestCase
    test_case: TestCase = field(init=False)
    test_description: TestDescription = field(init=False)
    sections: list[Section] = field(init=False, default_factory=list)
    lint_result: LintResult = field(init=False)

    def __post_init__(self):
        """Copy the case/description and create the initial sections."""
        self.test_case = self.ats_test_case.test_case
        self.test_description = self.ats_test_case.test_description
        self.sections = self._fresh_sections()
        self.lint_result = None

    @staticmethod
    def _fresh_sections() -> list[Section]:
        """Return one error-free :class:`Section` per name in ``SECTION_NAMES``."""
        return [
            Section(name=section_name, error_message=None)
            for section_name in SECTION_NAMES
        ]

    def _check_section_presence(self, section: Section) -> None:
        """Set ``section.error_message`` when the section is absent or empty.

        The section name is mapped to a ``TestDescription`` attribute by
        lower-casing it and replacing spaces with underscores. A missing
        attribute and a falsy attribute value are both treated as absent.

        Args:
            section: The section to check; mutated in place on failure.
        """
        attribute_name = section.name.lower().replace(" ", "_")
        section_content = getattr(self.test_description, attribute_name, None)
        if not section_content:
            section.error_message = MISSING_SECTION_ERROR_MESSAGE.format(
                section_name=section.name,
            )

    def _check_sections(self, section_names: list[str]) -> None:
        """Check the presence of several sections in the docstring.

        Only names that are tracked in ``self.sections`` and are listed in
        ``MANDATORY_SECTIONS`` are checked; any other name is skipped.

        Args:
            section_names: The list of section names to check.
        """
        for section_name in section_names:
            section = next(
                (candidate for candidate in self.sections if candidate.name == section_name),
                None,
            )
            if section is None:
                # Unknown name: nothing to check (previously an AttributeError).
                continue
            if section.name in MANDATORY_SECTIONS:
                self._check_section_presence(section)

    def _check_mandatory_sections(self) -> None:
        """Check the presence of every mandatory section."""
        self._check_sections(MANDATORY_SECTIONS)

    def _check_matching_approvals_and_steps(self) -> None:
        """Flag a mismatch between the number of approvals and verify steps.

        The mismatch is only reported when both counts are non-zero. An empty
        section is reported separately by the mandatory-section check.
        """
        nbr_of_approvals = len(self.test_description.approvals)
        nbr_of_verify_steps = len(self.test_description.verify_steps)
        logger.debug(f"Number of approvals: {nbr_of_approvals}")
        logger.debug(f"Number of verify steps: {nbr_of_verify_steps}")
        if (
            nbr_of_approvals
            and nbr_of_verify_steps
            and nbr_of_approvals != nbr_of_verify_steps
        ):
            self.sections.append(
                Section(
                    name="matching_approvals_steps",
                    error_message=MISMATCH_APPROVALS_VERIFY_ERROR_MESSAGE.format(
                        approvals=nbr_of_approvals,
                        verifies=nbr_of_verify_steps,
                    ),
                ),
            )

    def lint(self) -> bool:
        """Lint the test case docstring and return the linting result.

        Each call starts from a fresh set of sections, so calling ``lint``
        more than once does not accumulate duplicate error entries. A missing
        or blank docstring short-circuits all other checks.

        Returns:
            True if the test case docstring passes linting, False otherwise.
        """
        # Reset in place so existing references to ``self.sections`` stay valid.
        self.sections[:] = self._fresh_sections()

        docstring = self.test_case.docstring
        if not docstring or not docstring.strip():
            self.sections.append(
                Section(
                    name="docstring",
                    error_message="Missing docstring for test case.",
                ),
            )
            logger.error(
                f"Test case '{self.test_case.name}' failed linting: Missing docstring.",
            )
            return False

        self._check_mandatory_sections()
        self._check_matching_approvals_and_steps()

        failed_sections = {
            section.name: section.error_message
            for section in self.sections
            if section.error_message
        }
        if failed_sections:
            logger.error(
                f"Test case '{self.test_case.name}' "
                "failed linting for the following reasons:",
            )
            logger.error(
                "\n".join(
                    f"- {section}: {error}"
                    for section, error in failed_sections.items()
                ),
            )
            return False

        logger.info(f"Test case '{self.test_case.name}' passed linting.")
        return True
@dataclass
class ATSTestCasesLinter:
    """Lint multiple ATS test cases concurrently.

    Parameters
    ----------
    ats_test_cases:
        The list of ATS test cases to lint.
    lint_results:
        Not an init argument. :meth:`lint` fills it with
        ``{test_case_name: {"status": <bool>}}``. Test cases that share a
        name overwrite each other's entry.

    Example:
        (Doctest temporarily disabled due to API complexity)

        # >>> from ats_linter.data_classes import TestCase
        # >>> from ats_linter.linter import ATSTestCasesFactory, ATSTestCasesLinter
        # >>> test_case_1 = TestCase(
        # ...     name="Test case name 1",
        # ...     docstring="Test case description 1",
        # ...     code="def test_something_1(): pass",
        # ... )
        # >>> test_case_2 = TestCase(
        # ...     name="Test case name 2",
        # ...     docstring="Test case description 2",
        # ...     code="def test_something_2(): pass",
        # ... )
        # >>> factory = ATSTestCasesFactory([test_case_1, test_case_2])
        # >>> ats_test_cases_linter = ATSTestCasesLinter(factory.ats_test_cases)
        # >>> ats_test_cases_linter.lint()
    """

    ats_test_cases: list[ATSTestCase]
    lint_results: dict[str, Any] = field(init=False, default_factory=dict)

    # Upper bound on worker threads. Previously one thread was started per
    # test case, which is unbounded for large suites. This is not a dataclass
    # field because it carries no annotation.
    _MAX_WORKERS = 32

    def __post_init__(self):
        """Start with an empty results mapping."""
        self.lint_results = {}

    def lint(self) -> bool:
        """Lint every test case and report whether all of them passed.

        Each test case is handled by the module-level ``lint_ats_test_case``
        on a thread pool with at most ``_MAX_WORKERS`` threads. Results are
        written into ``self.lint_results`` under a shared lock.

        Returns:
            True if every test case passes linting (or there are none),
            False otherwise.
        """
        if not self.ats_test_cases:
            return True

        lock = Lock()
        max_workers = min(len(self.ats_test_cases), self._MAX_WORKERS)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    lint_ats_test_case,
                    ats_test_case,
                    self.lint_results,
                    lock,
                )
                for ats_test_case in self.ats_test_cases
            ]
            # Collect every result (no short-circuit) so all futures are awaited.
            results = [future.result() for future in as_completed(futures)]
        return all(results)
# Module-level function so it can be imported and tested directly.
def lint_ats_test_case(
    ats_test_case: "ATSTestCase",
    lint_results: dict[str, Any],
    lock: Lock,
) -> bool:
    """Lint a single test case and record the outcome.

    The outcome is stored as ``lint_results[name] = {"status": <bool>}``,
    where ``name`` is ``ats_test_case.test_case.name``. The write happens
    while holding ``lock``, so several worker threads can share one results
    dict. An exception raised while linting is logged with its traceback and
    recorded as a failure.

    Args:
        ats_test_case: The ATS test case to lint.
        lint_results: The dictionary to store linting results.
        lock: The lock guarding writes to ``lint_results``.

    Returns:
        True if the test case passes linting, False otherwise (including
        when linting raised an exception).
    """
    test_name = ats_test_case.test_case.name
    lint_result = False
    try:
        lint_result = LintTestCase(ats_test_case).lint()
    except Exception as e:
        # Worker-thread boundary: record a failure and keep the other test
        # cases running, but keep the traceback for diagnosis.
        logger.exception(f"Failed to lint test case '{test_name}': {e}")

    with lock:
        lint_results[test_name] = {"status": lint_result}
    return lint_result