Source code for kedro.runner.thread_runner

# Copyright 2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""``ThreadRunner`` is an ``AbstractRunner`` implementation. It can
be used to run the ``Pipeline`` in parallel groups formed by toposort
using threads.
"""
import warnings
from collections import Counter
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from itertools import chain
from typing import Set

from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node
from kedro.runner.runner import AbstractRunner, run_node


class ThreadRunner(AbstractRunner):
    """``ThreadRunner`` is an ``AbstractRunner`` implementation. It can
    be used to run the ``Pipeline`` in parallel groups formed by toposort
    using threads.
    """
    def __init__(self, max_workers: int = None, is_async: bool = False):
        """Instantiates the runner.

        Args:
            max_workers: Number of worker threads to spawn. If not set,
                it is calculated automatically from the pipeline configuration.
            is_async: Ignored by ``ThreadRunner``; if True, a warning is issued
                and the value is forced to False, because ``ThreadRunner``
                doesn't support loading and saving the node inputs and outputs
                asynchronously with threads. Defaults to False.

        Raises:
            ValueError: if ``max_workers`` is not positive.

        """
        if is_async:
            warnings.warn(
                "`ThreadRunner` doesn't support loading and saving the "
                "node inputs and outputs asynchronously with threads. "
                "Setting `is_async` to False."
            )
        super().__init__(is_async=False)

        if max_workers is not None and max_workers <= 0:
            raise ValueError("max_workers should be positive")

        self._max_workers = max_workers
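    # Note: ``ThreadRunner(is_async=True)`` emits the warning above and the
    # instance still runs with ``is_async=False``.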
    def create_default_data_set(self, ds_name: str) -> AbstractDataSet:
        """Factory method for creating the default data set for the runner.

        Args:
            ds_name: Name of the missing data set.

        Returns:
            An instance of an implementation of ``AbstractDataSet`` to be
            used for all unregistered data sets.

        """
        return MemoryDataSet()
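    # Illustration of the sizing heuristic used by ``_get_required_workers_count``
    # below: a pipeline of 6 nodes arranged in 3 topological layers can execute
    # at most 6 - 3 + 1 = 4 nodes concurrently, so 4 threads suffice (further
    # capped by ``max_workers`` when it is provided).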
    def _get_required_workers_count(self, pipeline: Pipeline):
        """
        Calculate the max number of threads required for the pipeline.
        """
        # Number of nodes is a safe upper-bound estimate.
        # It's also safe to reduce it by the number of layers minus one,
        # because each layer means some nodes depend on other nodes
        # and they cannot run in parallel.
        # It might not be a perfect solution, but it is good enough and simple.
        required_threads = len(pipeline.nodes) - len(pipeline.grouped_nodes) + 1

        return (
            min(required_threads, self._max_workers)
            if self._max_workers
            else required_threads
        )

    def _run(  # pylint: disable=too-many-locals,useless-suppression
        self, pipeline: Pipeline, catalog: DataCatalog, run_id: str = None
    ) -> None:
        """The abstract interface for running pipelines.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: The ``DataCatalog`` from which to fetch data.
            run_id: The id of the run.

        Raises:
            Exception: in case of any downstream node failure.

        """
        nodes = pipeline.nodes
        load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))
        node_dependencies = pipeline.node_dependencies
        todo_nodes = set(node_dependencies.keys())
        done_nodes = set()  # type: Set[Node]
        futures = set()
        done = None
        max_workers = self._get_required_workers_count(pipeline)

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            while True:
                # Submit every node whose dependencies have all completed.
                ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
                todo_nodes -= ready
                for node in ready:
                    futures.add(
                        pool.submit(run_node, node, catalog, self._is_async, run_id)
                    )
                if not futures:
                    assert not todo_nodes, (todo_nodes, done_nodes, ready, done)
                    break
                done, futures = wait(futures, return_when=FIRST_COMPLETED)
                for future in done:
                    try:
                        node = future.result()
                    except Exception:
                        self._suggest_resume_scenario(pipeline, done_nodes)
                        raise
                    done_nodes.add(node)

                    # Decrement load counts and release any data sets we've
                    # finished with; this is particularly important for the
                    # shared datasets we create above.
                    for data_set in node.inputs:
                        load_counts[data_set] -= 1
                        if (
                            load_counts[data_set] < 1
                            and data_set not in pipeline.inputs()
                        ):
                            catalog.release(data_set)
                    for data_set in node.outputs:
                        if (
                            load_counts[data_set] < 1
                            and data_set not in pipeline.outputs()
                        ):
                            catalog.release(data_set)
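
Below is a minimal usage sketch, not part of the module source above: it assumes
``ThreadRunner`` is importable from ``kedro.runner`` and builds a tiny two-node
pipeline whose independent nodes form a single topological layer and can
therefore run in parallel threads. The node names and data are illustrative.

    from kedro.io import DataCatalog, MemoryDataSet
    from kedro.pipeline import Pipeline, node
    from kedro.runner import ThreadRunner


    def double(xs):
        return [x * 2 for x in xs]


    def square(xs):
        return [x * x for x in xs]


    catalog = DataCatalog({"xs": MemoryDataSet([1, 2, 3])})
    pipeline = Pipeline(
        [
            node(double, "xs", "doubled", name="double_node"),
            node(square, "xs", "squared", name="square_node"),
        ]
    )

    # The two nodes have no dependency on each other, so ThreadRunner submits
    # both to its thread pool in the same scheduling wave. Outputs that are not
    # registered in the catalog are returned by ``run``.
    result = ThreadRunner(max_workers=2).run(pipeline, catalog)
    print(result)  # {'doubled': [2, 4, 6], 'squared': [1, 4, 9]}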