Source: dataflow_runner.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A runner implementation that submits a job for remote execution.

The runner will create a JSON description of the job graph and then submit it
to the Dataflow Service for remote execution by a worker.
"""

import logging
import threading
import time
import traceback
import urllib

import apache_beam as beam
from apache_beam import error
from apache_beam import coders
from apache_beam import pvalue
from apache_beam.internal import pickler
from apache_beam.internal.gcp import json_value
from apache_beam.pvalue import AsSideInput
from apache_beam.runners.dataflow.dataflow_metrics import DataflowMetrics
from apache_beam.runners.dataflow.internal import names
from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
from apache_beam.runners.dataflow.internal.names import PropertyNames
from apache_beam.runners.dataflow.internal.names import TransformNames
from apache_beam.runners.dataflow.ptransform_overrides import CreatePTransformOverride
from apache_beam.runners.runner import PValueCache
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineRunner
from apache_beam.runners.runner import PipelineState
from apache_beam.transforms.display import DisplayData
from apache_beam.typehints import typehints
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TestOptions
from apache_beam.utils.plugin import BeamPlugin

__all__ = ['DataflowRunner']


class DataflowRunner(PipelineRunner):
  """A runner that creates job graphs and submits them for remote execution.

  Every execution of the run() method will submit an independent job for
  remote execution that consists of the nodes reachable from the passed in
  node argument or entire graph if node is None. The run() method returns
  after the service created the job and will not wait for the job to finish
  if blocking is set to False.
  """

  # A list of PTransformOverride objects to be applied before running a pipeline
  # using DataflowRunner.
  # Currently this only works for overrides where the input and output types do
  # not change.
  # For internal SDK use only. This should not be updated by Beam pipeline
  # authors.
  _PTRANSFORM_OVERRIDES = [
      CreatePTransformOverride(),
  ]

  def __init__(self, cache=None):
    # Cache of CloudWorkflowStep protos generated while the runner
    # "executes" a pipeline.
    self._cache = cache if cache is not None else PValueCache()
    self._unique_step_id = 0

  def _get_unique_step_name(self):
    self._unique_step_id += 1
    return 's%s' % self._unique_step_id

  @staticmethod
  def poll_for_job_completion(runner, result):
    """Polls for the specified job to finish running (successfully or not)."""
    last_message_time = None
    last_message_hash = None
    last_error_rank = float('-inf')
    last_error_msg = None
    last_job_state = None
    # How long to wait after pipeline failure for the error
    # message to show up giving the reason for the failure.
    # It typically takes about 30 seconds.
    final_countdown_timer_secs = 50.0
    sleep_secs = 5.0

    # Try to prioritize the user-level traceback, if any.
    def rank_error(msg):
      if 'work item was attempted' in msg:
        return -1
      elif 'Traceback' in msg:
        return 1
      return 0

    job_id = result.job_id()
    while True:
      response = runner.dataflow_client.get_job(job_id)
      # If get() is called very soon after Create() the response may not contain
      # an initialized 'currentState' field.
      if response.currentState is not None:
        if response.currentState != last_job_state:
          logging.info('Job %s is in state %s', job_id, response.currentState)
          last_job_state = response.currentState
        if str(response.currentState) != 'JOB_STATE_RUNNING':
          # Stop checking for new messages on timeout, explanatory
          # message received, success, or a terminal job state caused
          # by the user that therefore doesn't require explanation.
          if (final_countdown_timer_secs <= 0.0
              or last_error_msg is not None
              or str(response.currentState) == 'JOB_STATE_DONE'
              or str(response.currentState) == 'JOB_STATE_CANCELLED'
              or str(response.currentState) == 'JOB_STATE_UPDATED'
              or str(response.currentState) == 'JOB_STATE_DRAINED'):
            break
          # The job has failed; ensure we see any final error messages.
          sleep_secs = 1.0      # poll faster during the final countdown
          final_countdown_timer_secs -= sleep_secs
      time.sleep(sleep_secs)

      # Get all messages since beginning of the job run or since last message.
      page_token = None
      while True:
        messages, page_token = runner.dataflow_client.list_messages(
            job_id, page_token=page_token, start_time=last_message_time)
        for m in messages:
          message = '%s: %s: %s' % (m.time, m.messageImportance, m.messageText)
          m_hash = hash(message)
          if last_message_hash is not None and m_hash == last_message_hash:
            # Skip the first message if it is the last message we got in the
            # previous round. This can happen because we use the
            # last_message_time as a parameter of the query for new messages.
            continue
          last_message_time = m.time
          last_message_hash = m_hash
          # Skip empty messages.
          if m.messageImportance is None:
            continue
          logging.info(message)
          if str(m.messageImportance) == 'JOB_MESSAGE_ERROR':
            if rank_error(m.messageText) >= last_error_rank:
              last_error_rank = rank_error(m.messageText)
              last_error_msg = m.messageText
        if not page_token:
          break

    result._job = response
    runner.last_error_msg = last_error_msg

  @staticmethod
  def group_by_key_input_visitor():
    # Imported here to avoid circular dependencies.
    from apache_beam.pipeline import PipelineVisitor

    class GroupByKeyInputVisitor(PipelineVisitor):
      """A visitor that replaces `Any` element type for input `PCollection` of
      a `GroupByKey` or `_GroupByKeyOnly` with a `KV` type.

      TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
      we could directly replace the coder instead of mutating the element type.
      """

      def visit_transform(self, transform_node):
        # Imported here to avoid circular dependencies.
        # pylint: disable=wrong-import-order, wrong-import-position
        from apache_beam.transforms.core import GroupByKey, _GroupByKeyOnly
        if isinstance(transform_node.transform, (GroupByKey, _GroupByKeyOnly)):
          pcoll = transform_node.inputs[0]
          input_type = pcoll.element_type
          # If input_type is not specified, then treat it as `Any`.
          if not input_type:
            input_type = typehints.Any
          if not isinstance(input_type, typehints.TupleHint.TupleConstraint):
            if isinstance(input_type, typehints.AnyTypeConstraint):
              # `Any` type needs to be replaced with a KV[Any, Any] to
              # force a KV coder as the main output coder for the pcollection
              # preceding a GroupByKey.
              pcoll.element_type = typehints.KV[typehints.Any, typehints.Any]
            else:
              # TODO: Handle other valid types,
              # e.g. Union[KV[str, int], KV[str, float]]
              raise ValueError(
                  "Input to GroupByKey must be of Tuple or Any type. "
                  "Found %s for %s" % (input_type, pcoll))

    return GroupByKeyInputVisitor()

  @staticmethod
  def flatten_input_visitor():
    # Imported here to avoid circular dependencies.
    from apache_beam.pipeline import PipelineVisitor

    class FlattenInputVisitor(PipelineVisitor):
      """A visitor that replaces the element type for input ``PCollection``s of
      a ``Flatten`` transform with that of the output ``PCollection``.
      """

      def visit_transform(self, transform_node):
        # Imported here to avoid circular dependencies.
        # pylint: disable=wrong-import-order, wrong-import-position
        from apache_beam import Flatten
        if isinstance(transform_node.transform, Flatten):
          output_pcoll = transform_node.outputs[None]
          for input_pcoll in transform_node.inputs:
            input_pcoll.element_type = output_pcoll.element_type

    return FlattenInputVisitor()

  def run(self, pipeline):
    """Remotely executes entire pipeline or parts reachable from node."""
    # Import here to avoid adding the dependency for local running scenarios.
    try:
      # pylint: disable=wrong-import-order, wrong-import-position
      from apache_beam.runners.dataflow.internal import apiclient
    except ImportError:
      raise ImportError(
          'Google Cloud Dataflow runner not available, '
          'please install apache_beam[gcp]')

    # Performing configured PTransform overrides.
    pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)

    # Add setup_options for all the BeamPlugin imports
    setup_options = pipeline._options.view_as(SetupOptions)
    plugins = BeamPlugin.get_all_plugin_paths()
    if setup_options.beam_plugins is not None:
      plugins = list(set(plugins + setup_options.beam_plugins))
    setup_options.beam_plugins = plugins

    self.job = apiclient.Job(pipeline._options)

    # Dataflow runner requires a KV type for GBK inputs, hence we enforce that
    # here.
    pipeline.visit(self.group_by_key_input_visitor())

    # Dataflow runner requires output type of the Flatten to be the same as the
    # inputs, hence we enforce that here.
    pipeline.visit(self.flatten_input_visitor())

    # The superclass's run will trigger a traversal of all reachable nodes.
    super(DataflowRunner, self).run(pipeline)

    test_options = pipeline._options.view_as(TestOptions)
    # If it is a dry run, return without submitting the job.
    if test_options.dry_run:
      return None

    # Get a Dataflow API client and set its options
    self.dataflow_client = apiclient.DataflowApplicationClient(
        pipeline._options)

    # Create the job
    result = DataflowPipelineResult(
        self.dataflow_client.create_job(self.job), self)

    self._metrics = DataflowMetrics(self.dataflow_client, result, self.job)
    result.metric_results = self._metrics
    return result

  def _get_typehint_based_encoding(self, typehint, window_coder):
    """Returns an encoding based on a typehint object."""
    return self._get_cloud_encoding(self._get_coder(typehint,
                                                    window_coder=window_coder))

  @staticmethod
  def _get_coder(typehint, window_coder):
    """Returns a coder based on a typehint object."""
    if window_coder:
      return coders.WindowedValueCoder(
          coders.registry.get_coder(typehint),
          window_coder=window_coder)
    return coders.registry.get_coder(typehint)

  def _get_cloud_encoding(self, coder):
    """Returns an encoding based on a coder object."""
    if not isinstance(coder, coders.Coder):
      raise TypeError('Coder object must inherit from coders.Coder: %s.' %
                      str(coder))
    return coder.as_cloud_object()

  def _get_side_input_encoding(self, input_encoding):
    """Returns an encoding for the output of a view transform.

    Args:
      input_encoding: encoding of current transform's input. Side inputs need
        this because the service will check that input and output types match.

    Returns:
      An encoding that matches the output and input encoding. This is essential
      for the View transforms introduced to produce side inputs to a ParDo.
    """
    return {
        '@type': input_encoding['@type'],
        'component_encodings': [input_encoding]
    }

  def _get_encoded_output_coder(self, transform_node, window_value=True):
    """Returns the cloud encoding of the coder for the output of a transform."""
    if (len(transform_node.outputs) == 1
        and transform_node.outputs[None].element_type is not None):
      # TODO(robertwb): Handle type hints for multi-output transforms.
      element_type = transform_node.outputs[None].element_type
    else:
      # TODO(silviuc): Remove this branch (and assert) when typehints are
      # propagated everywhere. Returning an 'Any' as type hint will trigger
      # usage of the fallback coder (i.e., cPickler).
      element_type = typehints.Any
    if window_value:
      window_coder = (
          transform_node.outputs[None].windowing.windowfn.get_window_coder())
    else:
      window_coder = None
    return self._get_typehint_based_encoding(
        element_type, window_coder=window_coder)

  def _add_step(self, step_kind, step_label, transform_node, side_tags=()):
    """Creates a Step object and adds it to the cache."""
    # Import here to avoid adding the dependency for local running scenarios.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.runners.dataflow.internal import apiclient
    step = apiclient.Step(step_kind, self._get_unique_step_name())
    self.job.proto.steps.append(step.proto)
    step.add_property(PropertyNames.USER_NAME, step_label)
    # Cache the node/step association for the main output of the transform node.
    self._cache.cache_output(transform_node, None, step)
    # If side_tags is not () then this is a multi-output transform node and we
    # need to cache the (node, tag, step) for each of the tags used to access
    # the outputs. This is essential because the keys used to search in the
    # cache always contain the tag.
    for tag in side_tags:
      self._cache.cache_output(transform_node, tag, step)

    # Finally, we add the display data items to the pipeline step.
    # If the transform contains no display data then an empty list is added.
    step.add_property(
        PropertyNames.DISPLAY_DATA,
        [item.get_dict() for item in
         DisplayData.create_from(transform_node.transform).items])

    return step

  def _add_singleton_step(self, label, full_label, tag, input_step):
    """Creates a CollectionToSingleton step used to handle ParDo side inputs."""
    # Import here to avoid adding the dependency for local running scenarios.
    from apache_beam.runners.dataflow.internal import apiclient
    step = apiclient.Step(TransformNames.COLLECTION_TO_SINGLETON, label)
    self.job.proto.steps.append(step.proto)
    step.add_property(PropertyNames.USER_NAME, full_label)
    step.add_property(
        PropertyNames.PARALLEL_INPUT,
        {'@type': 'OutputReference',
         PropertyNames.STEP_NAME: input_step.proto.name,
         PropertyNames.OUTPUT_NAME: input_step.get_output(tag)})
    step.encoding = self._get_side_input_encoding(input_step.encoding)
    step.add_property(
        PropertyNames.OUTPUT_INFO,
        [{PropertyNames.USER_NAME: (
            '%s.%s' % (full_label, PropertyNames.OUTPUT)),
          PropertyNames.ENCODING: step.encoding,
          PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])
    return step

  def run_Impulse(self, transform_node):
    standard_options = (
        transform_node.outputs[None].pipeline._options.view_as(StandardOptions))
    if standard_options.streaming:
      step = self._add_step(
          TransformNames.READ, transform_node.full_label, transform_node)
      step.add_property(PropertyNames.FORMAT, 'pubsub')
      step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, '_starting_signal/')
      step.encoding = self._get_encoded_output_coder(transform_node)
      step.add_property(
          PropertyNames.OUTPUT_INFO,
          [{PropertyNames.USER_NAME: (
              '%s.%s' % (
                  transform_node.full_label, PropertyNames.OUT)),
            PropertyNames.ENCODING: step.encoding,
            PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])
    else:
      raise ValueError(
          'Impulse source for batch pipelines has not been defined.')

  def run_Flatten(self, transform_node):
    step = self._add_step(TransformNames.FLATTEN,
                          transform_node.full_label, transform_node)
    inputs = []
    for one_input in transform_node.inputs:
      input_step = self._cache.get_pvalue(one_input)
      inputs.append(
          {'@type': 'OutputReference',
           PropertyNames.STEP_NAME: input_step.proto.name,
           PropertyNames.OUTPUT_NAME: input_step.get_output(one_input.tag)})
    step.add_property(PropertyNames.INPUTS, inputs)
    step.encoding = self._get_encoded_output_coder(transform_node)
    step.add_property(
        PropertyNames.OUTPUT_INFO,
        [{PropertyNames.USER_NAME: (
            '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
          PropertyNames.ENCODING: step.encoding,
          PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])

  def apply_WriteToBigQuery(self, transform, pcoll):
    standard_options = pcoll.pipeline._options.view_as(StandardOptions)
    if standard_options.streaming:
      if (transform.write_disposition ==
          beam.io.BigQueryDisposition.WRITE_TRUNCATE):
        raise RuntimeError('Can not use write truncation mode in streaming')
      return self.apply_PTransform(transform, pcoll)
    else:
      return pcoll | 'WriteToBigQuery' >> beam.io.Write(
          beam.io.BigQuerySink(
              transform.table_reference.tableId,
              transform.table_reference.datasetId,
              transform.table_reference.projectId,
              transform.schema,
              transform.create_disposition,
              transform.write_disposition))

  def apply_GroupByKey(self, transform, pcoll):
    # Infer coder of parent.
    #
    # TODO(ccy): make Coder inference and checking less specialized and more
    # comprehensive.
    parent = pcoll.producer
    if parent:
      coder = parent.transform._infer_output_coder()  # pylint: disable=protected-access
    if not coder:
      coder = self._get_coder(pcoll.element_type or typehints.Any, None)
    if not coder.is_kv_coder():
      raise ValueError(('Coder for the GroupByKey operation "%s" is not a '
                        'key-value coder: %s.') % (transform.label,
                                                   coder))
    # TODO(robertwb): Update the coder itself if it changed.
    coders.registry.verify_deterministic(
        coder.key_coder(), 'GroupByKey operation "%s"' % transform.label)

    return pvalue.PCollection(pcoll.pipeline)

  def run_GroupByKey(self, transform_node):
    input_tag = transform_node.inputs[0].tag
    input_step = self._cache.get_pvalue(transform_node.inputs[0])
    step = self._add_step(
        TransformNames.GROUP, transform_node.full_label, transform_node)
    step.add_property(
        PropertyNames.PARALLEL_INPUT,
        {'@type': 'OutputReference',
         PropertyNames.STEP_NAME: input_step.proto.name,
         PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
    step.encoding = self._get_encoded_output_coder(transform_node)
    step.add_property(
        PropertyNames.OUTPUT_INFO,
        [{PropertyNames.USER_NAME: (
            '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
          PropertyNames.ENCODING: step.encoding,
          PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])
    windowing = transform_node.transform.get_windowing(
        transform_node.inputs)
    step.add_property(
        PropertyNames.SERIALIZED_FN,
        self.serialize_windowing_strategy(windowing))

  def run_ParDo(self, transform_node):
    transform = transform_node.transform
    input_tag = transform_node.inputs[0].tag
    input_step = self._cache.get_pvalue(transform_node.inputs[0])

    # Attach side inputs.
    si_dict = {}
    # We must call self._cache.get_pvalue exactly once due to refcounting.
    si_labels = {}
    lookup_label = lambda side_pval: si_labels[side_pval]
    for side_pval in transform_node.side_inputs:
      assert isinstance(side_pval, AsSideInput)
      si_label = 'SideInput-' + self._get_unique_step_name()
      si_full_label = '%s/%s' % (transform_node.full_label, si_label)
      self._add_singleton_step(
          si_label, si_full_label, side_pval.pvalue.tag,
          self._cache.get_pvalue(side_pval.pvalue))
      si_dict[si_label] = {
          '@type': 'OutputReference',
          PropertyNames.STEP_NAME: si_label,
          PropertyNames.OUTPUT_NAME: PropertyNames.OUT}
      si_labels[side_pval] = si_label

    # Now create the step for the ParDo transform being handled.
    step = self._add_step(
        TransformNames.DO,
        transform_node.full_label + (
            '/Do' if transform_node.side_inputs else ''),
        transform_node,
        transform_node.transform.output_tags)
    fn_data = self._pardo_fn_data(transform_node, lookup_label)
    step.add_property(PropertyNames.SERIALIZED_FN, pickler.dumps(fn_data))
    step.add_property(
        PropertyNames.PARALLEL_INPUT,
        {'@type': 'OutputReference',
         PropertyNames.STEP_NAME: input_step.proto.name,
         PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
    # Add side inputs if any.
    step.add_property(PropertyNames.NON_PARALLEL_INPUTS, si_dict)

    # Generate description for the outputs. The output names
    # will be 'out' for main output and 'out_<tag>' for a tagged output.
    # Using 'out' as a tag will not clash with the name for main since it will
    # be transformed into 'out_out' internally.
    outputs = []
    step.encoding = self._get_encoded_output_coder(transform_node)
    # Add the main output to the description.
    outputs.append(
        {PropertyNames.USER_NAME: (
            '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
         PropertyNames.ENCODING: step.encoding,
         PropertyNames.OUTPUT_NAME: PropertyNames.OUT})
    for side_tag in transform.output_tags:
      # The assumption here is that all outputs will have the same typehint
      # and coder as the main output. This is certainly the case right now
      # but conceivably it could change in the future.
      outputs.append(
          {PropertyNames.USER_NAME: (
              '%s.%s' % (transform_node.full_label, side_tag)),
           PropertyNames.ENCODING: step.encoding,
           PropertyNames.OUTPUT_NAME: (
               '%s_%s' % (PropertyNames.OUT, side_tag))})
    step.add_property(PropertyNames.OUTPUT_INFO, outputs)

  @staticmethod
  def _pardo_fn_data(transform_node, get_label):
    transform = transform_node.transform
    si_tags_and_types = [  # pylint: disable=protected-access
        (get_label(side_pval), side_pval.__class__, side_pval._view_options())
        for side_pval in transform_node.side_inputs]
    return (transform.fn, transform.args, transform.kwargs, si_tags_and_types,
            transform_node.inputs[0].windowing)

  def apply_CombineValues(self, transform, pcoll):
    return pvalue.PCollection(pcoll.pipeline)

  def run_CombineValues(self, transform_node):
    transform = transform_node.transform
    input_tag = transform_node.inputs[0].tag
    input_step = self._cache.get_pvalue(transform_node.inputs[0])
    step = self._add_step(
        TransformNames.COMBINE, transform_node.full_label, transform_node)
    # Combiner functions do not take deferred side-inputs (i.e. PValues) and
    # therefore the code to handle extra args/kwargs is simpler than for the
    # DoFn's of the ParDo transform. The last, empty argument is where
    # side-input information would go.
    fn_data = (transform.fn, transform.args, transform.kwargs, ())
    step.add_property(PropertyNames.SERIALIZED_FN,
                      pickler.dumps(fn_data))
    step.add_property(
        PropertyNames.PARALLEL_INPUT,
        {'@type': 'OutputReference',
         PropertyNames.STEP_NAME: input_step.proto.name,
         PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
    # Note that the accumulator must not have a WindowedValue encoding, while
    # the output of this step does in fact have a WindowedValue encoding.
    accumulator_encoding = self._get_encoded_output_coder(transform_node,
                                                          window_value=False)
    output_encoding = self._get_encoded_output_coder(transform_node)

    step.encoding = output_encoding
    step.add_property(PropertyNames.ENCODING, accumulator_encoding)
    # Generate description for main output 'out.'
    outputs = []
    # Add the main output to the description.
    outputs.append(
        {PropertyNames.USER_NAME: (
            '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
         PropertyNames.ENCODING: step.encoding,
         PropertyNames.OUTPUT_NAME: PropertyNames.OUT})
    step.add_property(PropertyNames.OUTPUT_INFO, outputs)

  def run_Read(self, transform_node):
    transform = transform_node.transform
    step = self._add_step(
        TransformNames.READ, transform_node.full_label, transform_node)
    # TODO(mairbek): refactor if-else tree to use registerable functions.
    # Initialize the source specific properties.

    if not hasattr(transform.source, 'format'):
      # If a format is not set, we assume the source to be a custom source.
      source_dict = {}

      source_dict['spec'] = {
          '@type': names.SOURCE_TYPE,
          names.SERIALIZED_SOURCE_KEY: pickler.dumps(transform.source)
      }

      try:
        source_dict['metadata'] = {
            'estimated_size_bytes': json_value.get_typed_value_descriptor(
                transform.source.estimate_size())
        }
      except error.RuntimeValueProviderError:
        # Size estimation is best effort, and this error is by value provider.
        logging.info(
            'Could not estimate size of source %r due to ' + \
            'RuntimeValueProviderError', transform.source)
      except Exception:  # pylint: disable=broad-except
        # Size estimation is best effort. So we log the error and continue.
        logging.info(
            'Could not estimate size of source %r due to an exception: %s',
            transform.source, traceback.format_exc())

      step.add_property(PropertyNames.SOURCE_STEP_INPUT,
                        source_dict)
    elif transform.source.format == 'text':
      step.add_property(PropertyNames.FILE_PATTERN, transform.source.path)
    elif transform.source.format == 'bigquery':
      step.add_property(PropertyNames.BIGQUERY_EXPORT_FORMAT, 'FORMAT_AVRO')
      # TODO(silviuc): Add table validation if transform.source.validate.
      if transform.source.table_reference is not None:
        step.add_property(PropertyNames.BIGQUERY_DATASET,
                          transform.source.table_reference.datasetId)
        step.add_property(PropertyNames.BIGQUERY_TABLE,
                          transform.source.table_reference.tableId)
        # If project owning the table was not specified then the project owning
        # the workflow (current project) will be used.
        if transform.source.table_reference.projectId is not None:
          step.add_property(PropertyNames.BIGQUERY_PROJECT,
                            transform.source.table_reference.projectId)
      elif transform.source.query is not None:
        step.add_property(PropertyNames.BIGQUERY_QUERY, transform.source.query)
        step.add_property(PropertyNames.BIGQUERY_USE_LEGACY_SQL,
                          transform.source.use_legacy_sql)
        step.add_property(PropertyNames.BIGQUERY_FLATTEN_RESULTS,
                          transform.source.flatten_results)
      else:
        raise ValueError('BigQuery source %r must specify either a table or'
                         ' a query',
                         transform.source)
    elif transform.source.format == 'pubsub':
      standard_options = (
          transform_node.inputs[0].pipeline.options.view_as(StandardOptions))
      if not standard_options.streaming:
        raise ValueError('PubSubPayloadSource is currently available for use '
                         'only in streaming pipelines.')
      # Only one of topic or subscription should be set.
      if transform.source.full_subscription:
        step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION,
                          transform.source.full_subscription)
      elif transform.source.full_topic:
        step.add_property(PropertyNames.PUBSUB_TOPIC,
                          transform.source.full_topic)
      if transform.source.id_label:
        step.add_property(PropertyNames.PUBSUB_ID_LABEL,
                          transform.source.id_label)
    else:
      raise ValueError(
          'Source %r has unexpected format %s.' % (
              transform.source, transform.source.format))

    if not hasattr(transform.source, 'format'):
      step.add_property(PropertyNames.FORMAT, names.SOURCE_FORMAT)
    else:
      step.add_property(PropertyNames.FORMAT, transform.source.format)

    # Wrap coder in WindowedValueCoder: this is necessary as the encoding of a
    # step should be the type of value outputted by each step.  Read steps
    # automatically wrap output values in a WindowedValue wrapper, if necessary.
    # This is also necessary for proper encoding for size estimation.
    # Using a GlobalWindowCoder as a place holder instead of the default
    # PickleCoder because GlobalWindowCoder is a known coder.
    # TODO(robertwb): Query the collection for the windowfn to extract the
    # correct coder.
    coder = coders.WindowedValueCoder(transform._infer_output_coder(),
                                      coders.coders.GlobalWindowCoder())  # pylint: disable=protected-access

    step.encoding = self._get_cloud_encoding(coder)
    step.add_property(
        PropertyNames.OUTPUT_INFO,
        [{PropertyNames.USER_NAME: (
            '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
          PropertyNames.ENCODING: step.encoding,
          PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])

  def run__NativeWrite(self, transform_node):
    transform = transform_node.transform
    input_tag = transform_node.inputs[0].tag
    input_step = self._cache.get_pvalue(transform_node.inputs[0])
    step = self._add_step(
        TransformNames.WRITE, transform_node.full_label, transform_node)
    # TODO(mairbek): refactor if-else tree to use registerable functions.
    # Initialize the sink specific properties.
    if transform.sink.format == 'text':
      # Note that it is important to use typed properties (@type/value dicts)
      # for non-string properties and also for empty strings. For example,
      # in the code below the num_shards must have type and also
      # file_name_suffix and shard_name_template (could be empty strings).
      step.add_property(
          PropertyNames.FILE_NAME_PREFIX, transform.sink.file_name_prefix,
          with_type=True)
      step.add_property(
          PropertyNames.FILE_NAME_SUFFIX, transform.sink.file_name_suffix,
          with_type=True)
      step.add_property(
          PropertyNames.SHARD_NAME_TEMPLATE, transform.sink.shard_name_template,
          with_type=True)
      if transform.sink.num_shards > 0:
        step.add_property(
            PropertyNames.NUM_SHARDS, transform.sink.num_shards, with_type=True)
      # TODO(silviuc): Implement sink validation.
      step.add_property(PropertyNames.VALIDATE_SINK, False, with_type=True)
    elif transform.sink.format == 'bigquery':
      # TODO(silviuc): Add table validation if transform.sink.validate.
      step.add_property(PropertyNames.BIGQUERY_DATASET,
                        transform.sink.table_reference.datasetId)
      step.add_property(PropertyNames.BIGQUERY_TABLE,
                        transform.sink.table_reference.tableId)
      # If project owning the table was not specified then the project owning
      # the workflow (current project) will be used.
      if transform.sink.table_reference.projectId is not None:
        step.add_property(PropertyNames.BIGQUERY_PROJECT,
                          transform.sink.table_reference.projectId)
      step.add_property(PropertyNames.BIGQUERY_CREATE_DISPOSITION,
                        transform.sink.create_disposition)
      step.add_property(PropertyNames.BIGQUERY_WRITE_DISPOSITION,
                        transform.sink.write_disposition)
      if transform.sink.table_schema is not None:
        step.add_property(
            PropertyNames.BIGQUERY_SCHEMA, transform.sink.schema_as_json())
    elif transform.sink.format == 'pubsub':
      standard_options = (
          transform_node.inputs[0].pipeline.options.view_as(StandardOptions))
      if not standard_options.streaming:
        raise ValueError('PubSubPayloadSink is currently available for use '
                         'only in streaming pipelines.')
      step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.full_topic)
    else:
      raise ValueError(
          'Sink %r has unexpected format %s.' % (
              transform.sink, transform.sink.format))
    step.add_property(PropertyNames.FORMAT, transform.sink.format)

    # Wrap coder in WindowedValueCoder: this is necessary for proper encoding
    # for size estimation. Using a GlobalWindowCoder as a place holder instead
    # of the default PickleCoder because GlobalWindowCoder is a known coder.
    # TODO(robertwb): Query the collection for the windowfn to extract the
    # correct coder.
    coder = coders.WindowedValueCoder(transform.sink.coder,
                                      coders.coders.GlobalWindowCoder())
    step.encoding = self._get_cloud_encoding(coder)
    step.add_property(PropertyNames.ENCODING, step.encoding)
    step.add_property(
        PropertyNames.PARALLEL_INPUT,
        {'@type': 'OutputReference',
         PropertyNames.STEP_NAME: input_step.proto.name,
         PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})

  @classmethod
  def serialize_windowing_strategy(cls, windowing):
    from apache_beam.runners import pipeline_context
    from apache_beam.portability.api import beam_runner_api_pb2
    context = pipeline_context.PipelineContext()
    windowing_proto = windowing.to_runner_api(context)
    return cls.byte_array_to_json_string(
        beam_runner_api_pb2.MessageWithComponents(
            components=context.to_runner_api(),
            windowing_strategy=windowing_proto).SerializeToString())

  @classmethod
  def deserialize_windowing_strategy(cls, serialized_data):
    # Imported here to avoid circular dependencies.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.runners import pipeline_context
    from apache_beam.portability.api import beam_runner_api_pb2
    from apache_beam.transforms.core import Windowing
    proto = beam_runner_api_pb2.MessageWithComponents()
    proto.ParseFromString(cls.json_string_to_byte_array(serialized_data))
    return Windowing.from_runner_api(
        proto.windowing_strategy,
        pipeline_context.PipelineContext(proto.components))

  @staticmethod
  def byte_array_to_json_string(raw_bytes):
    """Implements org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString."""
    return urllib.quote(raw_bytes)

  @staticmethod
  def json_string_to_byte_array(encoded_string):
    """Implements org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray."""
    return urllib.unquote(encoded_string)


class DataflowPipelineResult(PipelineResult):
  """Represents the state of a pipeline run on the Dataflow service."""

  def __init__(self, job, runner):
    """Job is a Job message from the Dataflow API."""
    self._job = job
    self._runner = runner
    self.metric_results = None

  def job_id(self):
    return self._job.id

  def metrics(self):
    return self.metric_results

  @property
  def has_job(self):
    return self._job is not None

  @property
  def state(self):
    """Return the current state of the remote job.

    Returns:
      A PipelineState object.
    """
    if not self.has_job:
      return PipelineState.UNKNOWN

    values_enum = dataflow_api.Job.CurrentStateValueValuesEnum
    api_jobstate_map = {
        values_enum.JOB_STATE_UNKNOWN: PipelineState.UNKNOWN,
        values_enum.JOB_STATE_STOPPED: PipelineState.STOPPED,
        values_enum.JOB_STATE_RUNNING: PipelineState.RUNNING,
        values_enum.JOB_STATE_DONE: PipelineState.DONE,
        values_enum.JOB_STATE_FAILED: PipelineState.FAILED,
        values_enum.JOB_STATE_CANCELLED: PipelineState.CANCELLED,
        values_enum.JOB_STATE_UPDATED: PipelineState.UPDATED,
        values_enum.JOB_STATE_DRAINING: PipelineState.DRAINING,
        values_enum.JOB_STATE_DRAINED: PipelineState.DRAINED,
    }

    return (api_jobstate_map[self._job.currentState] if self._job.currentState
            else PipelineState.UNKNOWN)

  def _is_in_terminal_state(self):
    if not self.has_job:
      return True

    return self.state in [
        PipelineState.STOPPED, PipelineState.DONE, PipelineState.FAILED,
        PipelineState.CANCELLED, PipelineState.DRAINED]

  def wait_until_finish(self, duration=None):
    if not self._is_in_terminal_state():
      if not self.has_job:
        raise IOError('Failed to get the Dataflow job id.')
      if duration:
        raise NotImplementedError(
            'DataflowRunner does not support duration argument.')

      thread = threading.Thread(
          target=DataflowRunner.poll_for_job_completion,
          args=(self._runner, self))

      # Mark the thread as a daemon thread so a keyboard interrupt on the main
      # thread will terminate everything. This is also the reason we will not
      # use thread.join() to wait for the polling thread.
      thread.daemon = True
      thread.start()
      while thread.isAlive():
        time.sleep(5.0)
      if self.state != PipelineState.DONE:
        # TODO(BEAM-1290): Consider converting this to an error log based on the
        # resolution of the issue.
        raise DataflowRuntimeException(
            'Dataflow pipeline failed. State: %s, Error:\n%s' %
            (self.state, getattr(self._runner, 'last_error_msg', None)), self)
    return self.state

  def __str__(self):
    return '<%s %s %s>' % (
        self.__class__.__name__,
        self.job_id(),
        self.state)

  def __repr__(self):
    return '<%s %s at %s>' % (self.__class__.__name__, self._job, hex(id(self)))


class DataflowRuntimeException(Exception):
  """Indicates an error has occurred in running this pipeline."""

  def __init__(self, msg, result):
    super(DataflowRuntimeException, self).__init__(msg)
    ...
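The runner above is normally driven through Pipeline.run() rather than called directly. As a rough, hedged illustration of the API surface shown in this file (DataflowRunner selection via options, plus DataflowPipelineResult.job_id(), state, and wait_until_finish()), the sketch below submits a trivial batch job. It assumes apache_beam[gcp] is installed and Google Cloud credentials are configured; the project ID, bucket paths, and job name are placeholders, not real resources.

# Hypothetical usage sketch for the DataflowRunner shown above; all GCP
# identifiers below are placeholders.
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions([
    '--runner=DataflowRunner',
    '--project=my-gcp-project',                   # placeholder project id
    '--staging_location=gs://my-bucket/staging',  # placeholder GCS paths
    '--temp_location=gs://my-bucket/tmp',
    '--job_name=example-word-lengths',
])

pipeline = beam.Pipeline(options=options)
(pipeline
 | 'Create' >> beam.Create(['apache', 'beam', 'dataflow'])
 | 'Lengths' >> beam.Map(len)
 | 'Write' >> beam.io.WriteToText('gs://my-bucket/output/lengths'))

# run() returns a DataflowPipelineResult; wait_until_finish() polls the
# service (via poll_for_job_completion above) until a terminal state.
result = pipeline.run()
print('Submitted job %s in state %s' % (result.job_id(), result.state))
result.wait_until_finish()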
Source: pipeline.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Pipeline, the top-level Dataflow object.

A pipeline holds a DAG of data transforms. Conceptually the nodes of the DAG
are transforms (PTransform objects) and the edges are values (mostly PCollection
objects). The transforms take as inputs one or more PValues and output one or
more PValues.

The pipeline offers functionality to traverse the graph.  The actual operation
to be executed for each node visited is specified through a runner object.

Typical usage:

  # Create a pipeline object using a local runner for execution.
  with beam.Pipeline('DirectRunner') as p:

    # Add to the pipeline a "Create" transform. When executed this
    # transform will produce a PCollection object with the specified values.
    pcoll = p | 'Create' >> beam.Create([1, 2, 3])

    # Another transform could be applied to pcoll, e.g., writing to a text file.
    # For other transforms, refer to transforms/ directory.
    pcoll | 'Write' >> beam.io.WriteToText('./output')

    # run() will execute the DAG stored in the pipeline.  The execution of the
    # nodes visited is done using the specified local runner.
"""

from __future__ import absolute_import

import abc
import collections
import logging
import os
import shutil
import tempfile

from apache_beam import pvalue
from apache_beam.internal import pickler
from apache_beam.pvalue import PCollection
from apache_beam.runners import create_runner
from apache_beam.runners import PipelineRunner
from apache_beam.transforms import ptransform
from apache_beam.typehints import typehints
from apache_beam.typehints import TypeCheckError
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator
from apache_beam.utils.annotations import deprecated

__all__ = ['Pipeline']


class Pipeline(object):
  """A pipeline object that manages a DAG of PValues and their PTransforms.

  Conceptually the PValues are the DAG's nodes and the PTransforms computing
  the PValues are the edges.

  All the transforms applied to the pipeline must have distinct full labels.
  If the same transform instance needs to be applied then the right shift
  operator should be used to designate new names
  (e.g. `input | "label" >> my_transform`).
  """

  def __init__(self, runner=None, options=None, argv=None):
    """Initialize a pipeline object.

    Args:
      runner: An object of type 'PipelineRunner' that will be used to execute
        the pipeline. For registered runners, the runner name can be specified,
        otherwise a runner object must be supplied.
      options: A configured 'PipelineOptions' object containing arguments
        that should be used for running the Dataflow job.
      argv: a list of arguments (such as sys.argv) to be used for building a
        'PipelineOptions' object. This will only be used if argument 'options'
        is None.

    Raises:
      ValueError: if either the runner or options argument is not of the
      expected type.
    """
    if options is not None:
      if isinstance(options, PipelineOptions):
        self._options = options
      else:
        raise ValueError(
            'Parameter options, if specified, must be of type PipelineOptions. '
            'Received : %r', options)
    elif argv is not None:
      if isinstance(argv, list):
        self._options = PipelineOptions(argv)
      else:
        raise ValueError(
            'Parameter argv, if specified, must be a list. Received : %r', argv)
    else:
      self._options = PipelineOptions([])

    if runner is None:
      runner = self._options.view_as(StandardOptions).runner
      if runner is None:
        runner = StandardOptions.DEFAULT_RUNNER
        logging.info(('Missing pipeline option (runner). Executing pipeline '
                      'using the default runner: %s.'), runner)

    if isinstance(runner, str):
      runner = create_runner(runner)
    elif not isinstance(runner, PipelineRunner):
      raise TypeError('Runner must be a PipelineRunner object or the '
                      'name of a registered runner.')

    # Validate pipeline options
    errors = PipelineOptionsValidator(self._options, runner).validate()
    if errors:
      raise ValueError(
          'Pipeline has validation errors: \n' + '\n'.join(errors))

    # Default runner to be used.
    self.runner = runner
    # Stack of transforms generated by nested apply() calls. The stack will
    # contain a root node as an enclosing (parent) node for top transforms.
    self.transforms_stack = [AppliedPTransform(None, None, '', None)]
    # Set of transform labels (full labels) applied to the pipeline.
    # If a transform is applied and the full label is already in the set
    # then the transform will have to be cloned with a new label.
    self.applied_labels = set()

  @property
  @deprecated(since='First stable release',
              extra_message='References to <pipeline>.options'
              ' will not be supported')
  def options(self):
    return self._options

  def _current_transform(self):
    """Returns the transform currently on the top of the stack."""
    return self.transforms_stack[-1]

  def _root_transform(self):
    """Returns the root transform of the transform stack."""
    return self.transforms_stack[0]

  def _remove_labels_recursively(self, applied_transform):
    for part in applied_transform.parts:
      if part.full_label in self.applied_labels:
        self.applied_labels.remove(part.full_label)
      if part.parts:
        for part2 in part.parts:
          self._remove_labels_recursively(part2)

  def _replace(self, override):
    assert isinstance(override, PTransformOverride)
    matcher = override.get_matcher()

    output_map = {}
    output_replacements = {}
    input_replacements = {}

    class TransformUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
      """A visitor that replaces the matching PTransforms."""

      def __init__(self, pipeline):
        self.pipeline = pipeline

      def _replace_if_needed(self, transform_node):
        if matcher(transform_node):
          replacement_transform = override.get_replacement_transform(
              transform_node.transform)
          inputs = transform_node.inputs
          # TODO:  Support replacing PTransforms with multiple inputs.
          if len(inputs) > 1:
            raise NotImplementedError(
                'PTransform overriding is only supported for PTransforms that '
                'have a single input. Tried to replace input of '
                'AppliedPTransform %r that has %d inputs',
                transform_node, len(inputs))
          transform_node.transform = replacement_transform
          self.pipeline.transforms_stack.append(transform_node)

          # Keeping the same label for the replaced node but recursively
          # removing labels of child transforms since they will be replaced
          # during the expand below.
          self.pipeline._remove_labels_recursively(transform_node)

          new_output = replacement_transform.expand(inputs[0])
          if new_output.producer is None:
            # When current transform is a primitive, we set the producer here.
            new_output.producer = transform_node

          # We only support replacing transforms with a single output with
          # another transform that produces a single output.
          # TODO: Support replacing PTransforms with multiple outputs.
          if (len(transform_node.outputs) > 1 or
              not isinstance(transform_node.outputs[None], PCollection) or
              not isinstance(new_output, PCollection)):
            raise NotImplementedError(
                'PTransform overriding is only supported for PTransforms that '
                'have a single output. Tried to replace output of '
                'AppliedPTransform %r with %r.'
                , transform_node, new_output)

          # Recording updated outputs. This cannot be done in the same visitor
          # since if we dynamically update output type here, we'll run into
          # errors when visiting child nodes.
          output_map[transform_node.outputs[None]] = new_output

          self.pipeline.transforms_stack.pop()

      def enter_composite_transform(self, transform_node):
        self._replace_if_needed(transform_node)

      def visit_transform(self, transform_node):
        self._replace_if_needed(transform_node)

    self.visit(TransformUpdater(self))

    # Adjusting inputs and outputs
    class InputOutputUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
      """A visitor that records input and output values to be replaced.

      Input and output values that should be updated are recorded in maps
      input_replacements and output_replacements respectively.

      We cannot update input and output values while visiting since that
      results in validation errors.
      """

      def __init__(self, pipeline):
        self.pipeline = pipeline

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        if (None in transform_node.outputs and
            transform_node.outputs[None] in output_map):
          output_replacements[transform_node] = (
              output_map[transform_node.outputs[None]])

        replace_input = False
        for input in transform_node.inputs:
          if input in output_map:
            replace_input = True
            break

        if replace_input:
          new_input = [
              input if not input in output_map else output_map[input]
              for input in transform_node.inputs]
          input_replacements[transform_node] = new_input

    self.visit(InputOutputUpdater(self))

    for transform in output_replacements:
      transform.replace_output(output_replacements[transform])

    for transform in input_replacements:
      transform.inputs = input_replacements[transform]

  def _check_replacement(self, override):
    matcher = override.get_matcher()

    class ReplacementValidator(PipelineVisitor):
      def visit_transform(self, transform_node):
        if matcher(transform_node):
          raise RuntimeError('Transform node %r was not replaced as expected.',
                             transform_node)

    self.visit(ReplacementValidator())

  def replace_all(self, replacements):
    """Dynamically replaces PTransforms in the currently populated hierarchy.

    Currently this only works for replacements where input and output types
    are exactly the same.
    TODO: Update this to also work for transform overrides where input and
    output types are different.

    Args:
      replacements: a list of PTransformOverride objects.
    """
    for override in replacements:
      assert isinstance(override, PTransformOverride)
      self._replace(override)

    # Checking if the PTransforms have been successfully replaced. This will
    # result in a failure if a PTransform that was replaced in a given override
    # gets re-added in a subsequent override. This is not allowed and ordering
    # of PTransformOverride objects in 'replacements' is important.
    for override in replacements:
      self._check_replacement(override)

  def run(self, test_runner_api=True):
    """Runs the pipeline. Returns whatever our runner returns after running."""
    # When possible, invoke a round trip through the runner API.
    if test_runner_api and self._verify_runner_api_compatible():
      return Pipeline.from_runner_api(
          self.to_runner_api(), self.runner, self._options).run(False)

    if self._options.view_as(SetupOptions).save_main_session:
      # If this option is chosen, verify we can pickle the main session early.
      tmpdir = tempfile.mkdtemp()
      try:
        pickler.dump_session(os.path.join(tmpdir, 'main_session.pickle'))
      finally:
        shutil.rmtree(tmpdir)
    return self.runner.run(self)

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    if not exc_type:
      self.run().wait_until_finish()

  def visit(self, visitor):
    """Visits depth-first every node of a pipeline's DAG.

    Runner-internal implementation detail; no backwards-compatibility
    guarantees.

    Args:
      visitor: PipelineVisitor object whose callbacks will be called for each
        node visited. See PipelineVisitor comments.

    Raises:
      TypeError: if node is specified and is not a PValue.
      pipeline.PipelineError: if node is specified and does not belong to this
        pipeline instance.
    """
    visited = set()
    self._root_transform().visit(visitor, self, visited)

  def apply(self, transform, pvalueish=None, label=None):
    """Applies a custom transform using the pvalueish specified.

    Args:
      transform: the PTransform to apply.
      pvalueish: the input for the PTransform (typically a PCollection).
      label: label of the PTransform.

    Raises:
      TypeError: if the transform object extracted from the argument list is
        not a PTransform.
      RuntimeError: if the transform object was already applied to this
        pipeline and needs to be cloned in order to apply again.
    """
    if isinstance(transform, ptransform._NamedPTransform):
      return self.apply(transform.transform, pvalueish,
                        label or transform.label)

    if not isinstance(transform, ptransform.PTransform):
      raise TypeError("Expected a PTransform object, got %s" % transform)

    if label:
      # Fix self.label as it is inspected by some PTransform operations
      # (e.g. to produce error messages for type hint violations).
      try:
        old_label, transform.label = transform.label, label
        return self.apply(transform, pvalueish)
      finally:
        transform.label = old_label

    full_label = '/'.join([self._current_transform().full_label,
                           label or transform.label]).lstrip('/')
    if full_label in self.applied_labels:
      raise RuntimeError(
          'Transform "%s" does not have a stable unique label. '
          'This will prevent updating of pipelines. '
          'To apply a transform with a specified label write '
          'pvalue | "label" >> transform'
          % full_label)
    self.applied_labels.add(full_label)

    pvalueish, inputs = transform._extract_input_pvalues(pvalueish)
    try:
      inputs = tuple(inputs)
      for leaf_input in inputs:
        if not isinstance(leaf_input, pvalue.PValue):
          raise TypeError
    except TypeError:
      raise NotImplementedError(
          'Unable to extract PValue inputs from %s; either %s does not accept '
          'inputs of this format, or it does not properly override '
          '_extract_input_pvalues' % (pvalueish, transform))

    current = AppliedPTransform(
        self._current_transform(), transform, full_label, inputs)
    self._current_transform().add_part(current)
    self.transforms_stack.append(current)

    type_options = self._options.view_as(TypeOptions)
    if type_options.pipeline_type_check:
      transform.type_check_inputs(pvalueish)

    pvalueish_result = self.runner.apply(transform, pvalueish)

    if type_options is not None and type_options.pipeline_type_check:
      transform.type_check_outputs(pvalueish_result)

    for result in ptransform.GetPValues().visit(pvalueish_result):
      assert isinstance(result, (pvalue.PValue, pvalue.DoOutputsTuple))

      # Make sure we set the producer only for a leaf node in the transform
      # DAG. This way we preserve the last transform of a composite transform
      # as being the real producer of the result.
      if result.producer is None:
        result.producer = current
      # TODO(robertwb): Multi-input, multi-output inference.
      # TODO(robertwb): Ideally we'd do intersection here.
      if (type_options is not None and type_options.pipeline_type_check
          and isinstance(result, pvalue.PCollection)
          and not result.element_type):
        input_element_type = (
            inputs[0].element_type
            if len(inputs) == 1
            else typehints.Any)
        type_hints = transform.get_type_hints()
        declared_output_type = type_hints.simple_output_type(transform.label)
        if declared_output_type:
          input_types = type_hints.input_types
          if input_types and input_types[0]:
            declared_input_type = input_types[0][0]
            result.element_type = typehints.bind_type_variables(
                declared_output_type,
                typehints.match_type_variables(declared_input_type,
                                               input_element_type))
          else:
            result.element_type = declared_output_type
        else:
          result.element_type = transform.infer_output_type(input_element_type)

      assert isinstance(result.producer.inputs, tuple)
      current.add_output(result)

    if (type_options is not None and
        type_options.type_check_strictness == 'ALL_REQUIRED' and
        transform.get_type_hints().output_types is None):
      ptransform_name = '%s(%s)' % (transform.__class__.__name__, full_label)
      raise TypeCheckError('Pipeline type checking is enabled, however no '
                           'output type-hint was found for the '
                           'PTransform %s' % ptransform_name)

    current.update_input_refcounts()
    self.transforms_stack.pop()
    return pvalueish_result

  def __reduce__(self):
    # Some transforms contain a reference to their enclosing pipeline,
    # which in turn reference all other transforms (resulting in quadratic
    # time/space to pickle each transform individually).  As we don't
    # require pickled pipelines to be executable, break the chain here.
    return str, ('Pickled pipeline stub.',)

  def _verify_runner_api_compatible(self):

    class Visitor(PipelineVisitor):  # pylint: disable=used-before-assignment
      ok = True  # Really a nonlocal.

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        if transform_node.side_inputs:
          # No side inputs (yet).
          Visitor.ok = False
        try:
          # Transforms must be picklable.
          pickler.loads(pickler.dumps(transform_node.transform,
                                      enable_trace=False),
                        enable_trace=False)
        except Exception:
          Visitor.ok = False

      def visit_value(self, value, _):
        if isinstance(value, pvalue.PDone):
          Visitor.ok = False

    self.visit(Visitor())
    return Visitor.ok

  def to_runner_api(self):
    """For internal use only; no backwards-compatibility guarantees."""
    from apache_beam.runners import pipeline_context
    from apache_beam.portability.api import beam_runner_api_pb2
    context = pipeline_context.PipelineContext()
    # Mutates context; placing inline would force dependence on
    # argument evaluation order.
    root_transform_id = context.transforms.get_id(self._root_transform())
    proto = beam_runner_api_pb2.Pipeline(
        root_transform_ids=[root_transform_id],
        components=context.to_runner_api())
    return proto

  @staticmethod
  def from_runner_api(proto, runner, options):
    """For internal use only; no backwards-compatibility guarantees."""
    p = Pipeline(runner=runner, options=options)
    from apache_beam.runners import pipeline_context
    context = pipeline_context.PipelineContext(proto.components)
    root_transform_id, = proto.root_transform_ids
    p.transforms_stack = [
        context.transforms.get_by_id(root_transform_id)]
    # TODO(robertwb): These are only needed to continue construction. Omit?
    p.applied_labels = set([
        t.unique_name for t in proto.components.transforms.values()])
    for id in proto.components.pcollections:
      pcollection = context.pcollections.get_by_id(id)
      pcollection.pipeline = p
    # Inject PBegin input where necessary.
    from apache_beam.io.iobase import Read
    from apache_beam.transforms.core import Create
    has_pbegin = [Read, Create]
    for id in proto.components.transforms:
      transform = context.transforms.get_by_id(id)
      if not transform.inputs and transform.transform.__class__ in has_pbegin:
        transform.inputs = (pvalue.PBegin(p),)
    return p


class PipelineVisitor(object):
  """For internal use only; no backwards-compatibility guarantees.

  Visitor pattern class used to traverse a DAG of transforms
  (used internally by Pipeline for bookkeeping purposes).
  """

  def visit_value(self, value, producer_node):
    """Callback for visiting a PValue in the pipeline DAG.

    Args:
      value: PValue visited (typically a PCollection instance).
      producer_node: AppliedPTransform object whose transform produced the
        pvalue.
    """
    pass

  def visit_transform(self, transform_node):
    """Callback for visiting a transform leaf node in the pipeline DAG."""
    pass

  def enter_composite_transform(self, transform_node):
    """Callback for entering traversal of a composite transform node."""
    pass

  def leave_composite_transform(self, transform_node):
    """Callback for leaving traversal of a composite transform node."""
    pass


class AppliedPTransform(object):
  """For internal use only; no backwards-compatibility guarantees.

  A transform node representing an instance of applying a PTransform
  (used internally by Pipeline for bookkeeping purposes).
  """

  def __init__(self, parent, transform, full_label, inputs):
    self.parent = parent
    self.transform = transform
    # Note that we want the PipelineVisitor classes to use the full_label,
    # inputs, side_inputs, and outputs fields from this instance instead of the
    # ones of the PTransform instance associated with it. Doing this permits
    # reusing PTransform instances in different contexts (apply() calls) without
    # any interference.
This is particularly useful for composite transforms.479    self.full_label = full_label480    self.inputs = inputs or ()481    self.side_inputs = () if transform is None else tuple(transform.side_inputs)482    self.outputs = {}483    self.parts = []484    # Per tag refcount dictionary for PValues for which this node is a485    # root producer.486    self.refcounts = collections.defaultdict(int)487  def __repr__(self):488    return "%s(%s, %s)" % (self.__class__.__name__, self.full_label,489                           type(self.transform).__name__)490  def update_input_refcounts(self):491    """Increment refcounts for all transforms providing inputs."""492    def real_producer(pv):493      real = pv.producer494      while real.parts:495        real = real.parts[-1]496      return real497    if not self.is_composite():498      for main_input in self.inputs:499        if not isinstance(main_input, pvalue.PBegin):500          real_producer(main_input).refcounts[main_input.tag] += 1501      for side_input in self.side_inputs:502        real_producer(side_input.pvalue).refcounts[side_input.pvalue.tag] += 1503  def replace_output(self, output, tag=None):504    """Replaces the output defined by the given tag with the given output.505    Args:506      output: replacement output507      tag: tag of the output to be replaced.508    """509    if isinstance(output, pvalue.DoOutputsTuple):510      self.replace_output(output[output._main_tag])511    elif isinstance(output, pvalue.PValue):512      self.outputs[tag] = output513    else:514      raise TypeError("Unexpected output type: %s" % output)515  def add_output(self, output, tag=None):516    if isinstance(output, pvalue.DoOutputsTuple):517      self.add_output(output[output._main_tag])518    elif isinstance(output, pvalue.PValue):519      # TODO(BEAM-1833): Require tags when calling this method.520      if tag is None and None in self.outputs:521        tag = len(self.outputs)522      assert tag not in self.outputs523      self.outputs[tag] = output524    else:525      raise TypeError("Unexpected output type: %s" % output)526  def add_part(self, part):527    assert isinstance(part, AppliedPTransform)528    self.parts.append(part)529  def is_composite(self):530    """Returns whether this is a composite transform.531    A composite transform has parts (inner transforms) or isn't the532    producer for any of its outputs. 
(An example of a transform that533    is not a producer is one that returns its inputs instead.)534    """535    return bool(self.parts) or all(536        pval.producer is not self for pval in self.outputs.values())537  def visit(self, visitor, pipeline, visited):538    """Visits all nodes reachable from the current node."""539    for pval in self.inputs:540      if pval not in visited and not isinstance(pval, pvalue.PBegin):541        assert pval.producer is not None542        pval.producer.visit(visitor, pipeline, visited)543        # The value should be visited now since we visit outputs too.544        assert pval in visited, pval545    # Visit side inputs.546    for pval in self.side_inputs:547      if isinstance(pval, pvalue.AsSideInput) and pval.pvalue not in visited:548        pval = pval.pvalue  # Unpack marker-object-wrapped pvalue.549        assert pval.producer is not None550        pval.producer.visit(visitor, pipeline, visited)551        # The value should be visited now since we visit outputs too.552        assert pval in visited553        # TODO(silviuc): Is there a way to signal that we are visiting a side554        # value? The issue is that the same PValue can be reachable through555        # multiple paths and therefore it is not guaranteed that the value556        # will be visited as a side value.557    # Visit a composite or primitive transform.558    if self.is_composite():559      visitor.enter_composite_transform(self)560      for part in self.parts:561        part.visit(visitor, pipeline, visited)562      visitor.leave_composite_transform(self)563    else:564      visitor.visit_transform(self)565    # Visit the outputs (one or more). It is essential to mark as visited the566    # tagged PCollections of the DoOutputsTuple object. A tagged PCollection is567    # connected directly with its producer (a multi-output ParDo), but the568    # output of such a transform is the containing DoOutputsTuple, not the569    # PCollection inside it. 
Without the code below a tagged PCollection will570    # not be marked as visited while visiting its producer.571    for pval in self.outputs.values():572      if isinstance(pval, pvalue.DoOutputsTuple):573        pvals = (v for v in pval)574      else:575        pvals = (pval,)576      for v in pvals:577        if v not in visited:578          visited.add(v)579          visitor.visit_value(v, self)580  def named_inputs(self):581    # TODO(BEAM-1833): Push names up into the sdk construction.582    return {str(ix): input for ix, input in enumerate(self.inputs)583            if isinstance(input, pvalue.PCollection)}584  def named_outputs(self):585    return {str(tag): output for tag, output in self.outputs.items()586            if isinstance(output, pvalue.PCollection)}587  def to_runner_api(self, context):588    from apache_beam.portability.api import beam_runner_api_pb2589    def transform_to_runner_api(transform, context):590      if transform is None:591        return None592      else:593        return transform.to_runner_api(context)594    return beam_runner_api_pb2.PTransform(595        unique_name=self.full_label,596        spec=transform_to_runner_api(self.transform, context),597        subtransforms=[context.transforms.get_id(part) for part in self.parts],598        # TODO(BEAM-115): Side inputs.599        inputs={tag: context.pcollections.get_id(pc)600                for tag, pc in self.named_inputs().items()},601        outputs={str(tag): context.pcollections.get_id(out)602                 for tag, out in self.named_outputs().items()},603        # TODO(BEAM-115): display_data604        display_data=None)605  @staticmethod606  def from_runner_api(proto, context):607    result = AppliedPTransform(608        parent=None,609        transform=ptransform.PTransform.from_runner_api(proto.spec, context),610        full_label=proto.unique_name,611        inputs=[612            context.pcollections.get_by_id(id) for id in proto.inputs.values()])613    result.parts = [614        context.transforms.get_by_id(id) for id in proto.subtransforms]615    result.outputs = {616        None if tag == 'None' else tag: context.pcollections.get_by_id(id)617        for tag, id in proto.outputs.items()}618    if not result.parts:619      for tag, pc in result.outputs.items():620        if pc not in result.inputs:621          pc.producer = result622          pc.tag = tag623    result.update_input_refcounts()624    return result625class PTransformOverride(object):626  """For internal use only; no backwards-compatibility guarantees.627  Gives a matcher and replacements for matching PTransforms.628  TODO: Update this to support cases where input and/our output types are629  different.630  """631  __metaclass__ = abc.ABCMeta632  @abc.abstractmethod633  def get_matcher(self):634    """Gives a matcher that will be used to to perform this override.635    Returns:636      a callable that takes an AppliedPTransform as a parameter and returns a637      boolean as a result.638    """639    raise NotImplementedError640  @abc.abstractmethod641  def get_replacement_transform(self, ptransform):642    """Provides a runner specific override for a given PTransform.643    Args:644      ptransform: PTransform to be replaced.645    Returns:646      A PTransform that will be the replacement for the PTransform given as an647      argument.648    """649    # Returns a PTransformReplacement...maptask_executor_runner.py
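The snippet above shows how Pipeline.apply() enforces stable, unique transform labels and how a PipelineVisitor walks the resulting DAG. The code that follows is a minimal, hypothetical sketch of both usage patterns, assuming the PipelineVisitor base shown above is importable as apache_beam.pipeline.PipelineVisitor; the pipeline contents and the CountingVisitor class are illustrative only, not part of the snippet.

import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor  # base class shown above (assumed import path)

class CountingVisitor(PipelineVisitor):
    """Counts leaf transforms via the visit_transform callback documented above."""
    def __init__(self):
        self.leaf_transforms = 0
    def visit_transform(self, transform_node):
        self.leaf_transforms += 1

with beam.Pipeline() as p:
    # 'pvalue | "label" >> transform' is the labelling form apply() asks for
    # when a transform would otherwise get a non-unique label.
    _ = (p
         | 'CreateNumbers' >> beam.Create([1, 2, 3])
         | 'Square' >> beam.Map(lambda x: x * x))
    visitor = CountingVisitor()
    p.visit(visitor)  # depth-first traversal, as documented in visit() above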
Source:maptask_executor_runner.py  
1#2# Licensed to the Apache Software Foundation (ASF) under one or more3# contributor license agreements.  See the NOTICE file distributed with4# this work for additional information regarding copyright ownership.5# The ASF licenses this file to You under the Apache License, Version 2.06# (the "License"); you may not use this file except in compliance with7# the License.  You may obtain a copy of the License at8#9#    http://www.apache.org/licenses/LICENSE-2.010#11# Unless required by applicable law or agreed to in writing, software12# distributed under the License is distributed on an "AS IS" BASIS,13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.14# See the License for the specific language governing permissions and15# limitations under the License.16#17"""Beam runner for testing/profiling worker code directly.18"""19import collections20import logging21import time22import apache_beam as beam23from apache_beam.internal import pickler24from apache_beam.io import iobase25from apache_beam.metrics.execution import MetricsEnvironment26from apache_beam.options import pipeline_options27from apache_beam.runners import DataflowRunner28from apache_beam.runners.dataflow.internal.dependency import _dependency_file_copy29from apache_beam.runners.dataflow.internal.names import PropertyNames30from apache_beam.runners.dataflow.native_io.iobase import NativeSource31from apache_beam.runners.runner import PipelineResult32from apache_beam.runners.runner import PipelineRunner33from apache_beam.runners.runner import PipelineState34from apache_beam.runners.worker import operation_specs35from apache_beam.runners.worker import operations36try:37  from apache_beam.runners.worker import statesampler38except ImportError:39  from apache_beam.runners.worker import statesampler_fake as statesampler40from apache_beam.typehints import typehints41from apache_beam.utils import profiler42from apache_beam.utils.counters import CounterFactory43# This module is experimental. 
No backwards-compatibility guarantees.44class MapTaskExecutorRunner(PipelineRunner):45  """Beam runner translating a pipeline into map tasks that are then executed.46  Primarily intended for testing and profiling the worker code paths.47  """48  def __init__(self):49    self.executors = []50  def has_metrics_support(self):51    """Returns whether this runner supports metrics or not.52    """53    return False54  def run(self, pipeline):55    MetricsEnvironment.set_metrics_supported(self.has_metrics_support())56    # List of map tasks  Each map task is a list of57    # (stage_name, operation_specs.WorkerOperation) instructions.58    self.map_tasks = []59    # Map of pvalues to60    # (map_task_index, producer_operation_index, producer_output_index)61    self.outputs = {}62    # Unique mappings of PCollections to strings.63    self.side_input_labels = collections.defaultdict(64        lambda: str(len(self.side_input_labels)))65    # Mapping of map task indices to all map tasks that must preceed them.66    self.dependencies = collections.defaultdict(set)67    # Visit the graph, building up the map_tasks and their metadata.68    super(MapTaskExecutorRunner, self).run(pipeline)69    # Now run the tasks in topological order.70    def compute_depth_map(deps):71      memoized = {}72      def compute_depth(x):73        if x not in memoized:74          memoized[x] = 1 + max([-1] + [compute_depth(y) for y in deps[x]])75        return memoized[x]76      return {x: compute_depth(x) for x in deps.keys()}77    map_task_depths = compute_depth_map(self.dependencies)78    ordered_map_tasks = sorted((map_task_depths.get(ix, -1), map_task)79                               for ix, map_task in enumerate(self.map_tasks))80    profile_options = pipeline.options.view_as(81        pipeline_options.ProfilingOptions)82    if profile_options.profile_cpu:83      with profiler.Profile(84          profile_id='worker-runner',85          profile_location=profile_options.profile_location,86          log_results=True, file_copy_fn=_dependency_file_copy):87        self.execute_map_tasks(ordered_map_tasks)88    else:89      self.execute_map_tasks(ordered_map_tasks)90    return WorkerRunnerResult(PipelineState.UNKNOWN)91  def metrics_containers(self):92    return [op.metrics_container93            for ex in self.executors94            for op in ex.operations()]95  def execute_map_tasks(self, ordered_map_tasks):96    tt = time.time()97    for ix, (_, map_task) in enumerate(ordered_map_tasks):98      logging.info('Running %s', map_task)99      t = time.time()100      stage_names, all_operations = zip(*map_task)101      # TODO(robertwb): The DataflowRunner worker receives system step names102      # (e.g. "s3") that are used to label the output msec counters.  
We use the103      # operation names here, but this is not the same scheme used by the104      # DataflowRunner; the result is that the output msec counters are named105      # differently.106      system_names = stage_names107      # Create the CounterFactory and StateSampler for this MapTask.108      # TODO(robertwb): Output counters produced here are currently ignored.109      counter_factory = CounterFactory()110      state_sampler = statesampler.StateSampler('%s-' % ix, counter_factory)111      map_executor = operations.SimpleMapTaskExecutor(112          operation_specs.MapTask(113              all_operations, 'S%02d' % ix,114              system_names, stage_names, system_names),115          counter_factory,116          state_sampler)117      self.executors.append(map_executor)118      map_executor.execute()119      logging.info(120          'Stage %s finished: %0.3f sec', stage_names[0], time.time() - t)121    logging.info('Total time: %0.3f sec', time.time() - tt)122  def run_Read(self, transform_node):123    self._run_read_from(transform_node, transform_node.transform.source)124  def _run_read_from(self, transform_node, source):125    """Used when this operation is the result of reading source."""126    if not isinstance(source, NativeSource):127      source = iobase.SourceBundle(1.0, source, None, None)128    output = transform_node.outputs[None]129    element_coder = self._get_coder(output)130    read_op = operation_specs.WorkerRead(source, output_coders=[element_coder])131    self.outputs[output] = len(self.map_tasks), 0, 0132    self.map_tasks.append([(transform_node.full_label, read_op)])133    return len(self.map_tasks) - 1134  def run_ParDo(self, transform_node):135    transform = transform_node.transform136    output = transform_node.outputs[None]137    element_coder = self._get_coder(output)138    map_task_index, producer_index, output_index = self.outputs[139        transform_node.inputs[0]]140    # If any of this ParDo's side inputs depend on outputs from this map_task,141    # we can't continue growing this map task.142    def is_reachable(leaf, root):143      if leaf == root:144        return True145      else:146        return any(is_reachable(x, root) for x in self.dependencies[leaf])147    if any(is_reachable(self.outputs[side_input.pvalue][0], map_task_index)148           for side_input in transform_node.side_inputs):149      # Start a new map tasks.150      input_element_coder = self._get_coder(transform_node.inputs[0])151      output_buffer = OutputBuffer(input_element_coder)152      fusion_break_write = operation_specs.WorkerInMemoryWrite(153          output_buffer=output_buffer,154          write_windowed_values=True,155          input=(producer_index, output_index),156          output_coders=[input_element_coder])157      self.map_tasks[map_task_index].append(158          (transform_node.full_label + '/Write', fusion_break_write))159      original_map_task_index = map_task_index160      map_task_index, producer_index, output_index = len(self.map_tasks), 0, 0161      fusion_break_read = operation_specs.WorkerRead(162          output_buffer.source_bundle(),163          output_coders=[input_element_coder])164      self.map_tasks.append(165          [(transform_node.full_label + '/Read', fusion_break_read)])166      self.dependencies[map_task_index].add(original_map_task_index)167    def create_side_read(side_input):168      label = self.side_input_labels[side_input]169      output_buffer = self.run_side_write(170          side_input.pvalue, '%s/%s' % 
(transform_node.full_label, label))171      return operation_specs.WorkerSideInputSource(172          output_buffer.source(), label)173    do_op = operation_specs.WorkerDoFn(  #174        serialized_fn=pickler.dumps(DataflowRunner._pardo_fn_data(175            transform_node,176            lambda side_input: self.side_input_labels[side_input])),177        output_tags=[PropertyNames.OUT] + ['%s_%s' % (PropertyNames.OUT, tag)178                                           for tag in transform.output_tags179                                          ],180        # Same assumption that DataflowRunner has about coders being compatible181        # across outputs.182        output_coders=[element_coder] * (len(transform.output_tags) + 1),183        input=(producer_index, output_index),184        side_inputs=[create_side_read(side_input)185                     for side_input in transform_node.side_inputs])186    producer_index = len(self.map_tasks[map_task_index])187    self.outputs[transform_node.outputs[None]] = (188        map_task_index, producer_index, 0)189    for ix, tag in enumerate(transform.output_tags):190      self.outputs[transform_node.outputs[191          tag]] = map_task_index, producer_index, ix + 1192    self.map_tasks[map_task_index].append((transform_node.full_label, do_op))193    for side_input in transform_node.side_inputs:194      self.dependencies[map_task_index].add(self.outputs[side_input.pvalue][0])195  def run_side_write(self, pcoll, label):196    map_task_index, producer_index, output_index = self.outputs[pcoll]197    windowed_element_coder = self._get_coder(pcoll)198    output_buffer = OutputBuffer(windowed_element_coder)199    write_sideinput_op = operation_specs.WorkerInMemoryWrite(200        output_buffer=output_buffer,201        write_windowed_values=True,202        input=(producer_index, output_index),203        output_coders=[windowed_element_coder])204    self.map_tasks[map_task_index].append(205        (label, write_sideinput_op))206    return output_buffer207  def run__GroupByKeyOnly(self, transform_node):208    map_task_index, producer_index, output_index = self.outputs[209        transform_node.inputs[0]]210    grouped_element_coder = self._get_coder(transform_node.outputs[None],211                                            windowed=False)212    windowed_ungrouped_element_coder = self._get_coder(transform_node.inputs[0])213    output_buffer = GroupingOutputBuffer(grouped_element_coder)214    shuffle_write = operation_specs.WorkerInMemoryWrite(215        output_buffer=output_buffer,216        write_windowed_values=False,217        input=(producer_index, output_index),218        output_coders=[windowed_ungrouped_element_coder])219    self.map_tasks[map_task_index].append(220        (transform_node.full_label + '/Write', shuffle_write))221    output_map_task_index = self._run_read_from(222        transform_node, output_buffer.source())223    self.dependencies[output_map_task_index].add(map_task_index)224  def run_Flatten(self, transform_node):225    output_buffer = OutputBuffer(self._get_coder(transform_node.outputs[None]))226    output_map_task = self._run_read_from(transform_node,227                                          output_buffer.source())228    for input in transform_node.inputs:229      map_task_index, producer_index, output_index = self.outputs[input]230      element_coder = self._get_coder(input)231      flatten_write = operation_specs.WorkerInMemoryWrite(232          output_buffer=output_buffer,233          write_windowed_values=True,234          
input=(producer_index, output_index),235          output_coders=[element_coder])236      self.map_tasks[map_task_index].append(237          (transform_node.full_label + '/Write', flatten_write))238      self.dependencies[output_map_task].add(map_task_index)239  def apply_CombinePerKey(self, transform, input):240    # TODO(robertwb): Support side inputs.241    assert not transform.args and not transform.kwargs242    return (input243            | PartialGroupByKeyCombineValues(transform.fn)244            | beam.GroupByKey()245            | MergeAccumulators(transform.fn)246            | ExtractOutputs(transform.fn))247  def run_PartialGroupByKeyCombineValues(self, transform_node):248    element_coder = self._get_coder(transform_node.outputs[None])249    _, producer_index, output_index = self.outputs[transform_node.inputs[0]]250    combine_op = operation_specs.WorkerPartialGroupByKey(251        combine_fn=pickler.dumps(252            (transform_node.transform.combine_fn, (), {}, ())),253        output_coders=[element_coder],254        input=(producer_index, output_index))255    self._run_as_op(transform_node, combine_op)256  def run_MergeAccumulators(self, transform_node):257    self._run_combine_transform(transform_node, 'merge')258  def run_ExtractOutputs(self, transform_node):259    self._run_combine_transform(transform_node, 'extract')260  def _run_combine_transform(self, transform_node, phase):261    transform = transform_node.transform262    element_coder = self._get_coder(transform_node.outputs[None])263    _, producer_index, output_index = self.outputs[transform_node.inputs[0]]264    combine_op = operation_specs.WorkerCombineFn(265        serialized_fn=pickler.dumps(266            (transform.combine_fn, (), {}, ())),267        phase=phase,268        output_coders=[element_coder],269        input=(producer_index, output_index))270    self._run_as_op(transform_node, combine_op)271  def _get_coder(self, pvalue, windowed=True):272    # TODO(robertwb): This should be an attribute of the pvalue itself.273    return DataflowRunner._get_coder(274        pvalue.element_type or typehints.Any,275        pvalue.windowing.windowfn.get_window_coder() if windowed else None)276  def _run_as_op(self, transform_node, op):277    """Single-output operation in the same map task as its input."""278    map_task_index, _, _ = self.outputs[transform_node.inputs[0]]279    op_index = len(self.map_tasks[map_task_index])280    output = transform_node.outputs[None]281    self.outputs[output] = map_task_index, op_index, 0282    self.map_tasks[map_task_index].append((transform_node.full_label, op))283class InMemorySource(iobase.BoundedSource):284  """Source for reading an (as-yet unwritten) set of in-memory encoded elements.285  """286  def __init__(self, encoded_elements, coder):287    self._encoded_elements = encoded_elements288    self._coder = coder289  def get_range_tracker(self, unused_start_position, unused_end_position):290    return None291  def read(self, unused_range_tracker):292    for encoded_element in self._encoded_elements:293      yield self._coder.decode(encoded_element)294  def default_output_coder(self):295    return self._coder296class OutputBuffer(object):297  def __init__(self, coder):298    self.coder = coder299    self.elements = []300    self.encoded_elements = []301  def source(self):302    return InMemorySource(self.encoded_elements, self.coder)303  def source_bundle(self):304    return iobase.SourceBundle(305        1.0, InMemorySource(self.encoded_elements, self.coder), None, None)306  
def __repr__(self):307    return 'GroupingOutput[%r]' % len(self.elements)308  def append(self, value):309    self.elements.append(value)310    self.encoded_elements.append(self.coder.encode(value))311class GroupingOutputBuffer(object):312  def __init__(self, grouped_coder):313    self.grouped_coder = grouped_coder314    self.elements = collections.defaultdict(list)315    self.frozen = False316  def source(self):317    return InMemorySource(self.encoded_elements, self.grouped_coder)318  def __repr__(self):319    return 'GroupingOutputBuffer[%r]' % len(self.elements)320  def append(self, pair):321    assert not self.frozen322    k, v = pair323    self.elements[k].append(v)324  def freeze(self):325    if not self.frozen:326      self._encoded_elements = [self.grouped_coder.encode(kv)327                                for kv in self.elements.iteritems()]328    self.frozen = True329    return self._encoded_elements330  @property331  def encoded_elements(self):332    return GroupedOutputBuffer(self)333class GroupedOutputBuffer(object):334  def __init__(self, buffer):335    self.buffer = buffer336  def __getitem__(self, ix):337    return self.buffer.freeze()[ix]338  def __iter__(self):339    return iter(self.buffer.freeze())340  def __len__(self):341    return len(self.buffer.freeze())342  def __nonzero__(self):343    return True344class PartialGroupByKeyCombineValues(beam.PTransform):345  def __init__(self, combine_fn, native=True):346    self.combine_fn = combine_fn347    self.native = native348  def expand(self, input):349    if self.native:350      return beam.pvalue.PCollection(input.pipeline)351    else:352      def to_accumulator(v):353        return self.combine_fn.add_input(354            self.combine_fn.create_accumulator(), v)355      return input | beam.Map(lambda (k, v): (k, to_accumulator(v)))356class MergeAccumulators(beam.PTransform):357  def __init__(self, combine_fn, native=True):358    self.combine_fn = combine_fn359    self.native = native360  def expand(self, input):361    if self.native:362      return beam.pvalue.PCollection(input.pipeline)363    else:364      merge_accumulators = self.combine_fn.merge_accumulators365      return input | beam.Map(lambda (k, vs): (k, merge_accumulators(vs)))366class ExtractOutputs(beam.PTransform):367  def __init__(self, combine_fn, native=True):368    self.combine_fn = combine_fn369    self.native = native370  def expand(self, input):371    if self.native:372      return beam.pvalue.PCollection(input.pipeline)373    else:374      extract_output = self.combine_fn.extract_output375      return input | beam.Map(lambda (k, v): (k, extract_output(v)))376class WorkerRunnerResult(PipelineResult):377  def wait_until_finish(self, duration=None):...client.py
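MapTaskExecutorRunner.run() above orders map tasks by dependency depth before executing them: a task's depth is one more than the deepest task it depends on, and sorting by depth yields a valid execution order. Below is a standalone sketch of that depth computation; only compute_depth_map mirrors the runner, while the dependency set and the surrounding usage are illustrative.

import collections

def compute_depth_map(deps):
    """Map each task index to its dependency depth (tasks with no recorded
    dependencies get depth 0); sorting by depth gives a valid run order."""
    memoized = {}
    def compute_depth(x):
        if x not in memoized:
            # Depth is one more than the deepest dependency (-1 if none).
            memoized[x] = 1 + max([-1] + [compute_depth(y) for y in deps[x]])
        return memoized[x]
    # list() guards against the defaultdict growing while we iterate.
    return {x: compute_depth(x) for x in list(deps)}

# Illustrative dependencies: task 1 needs task 0, task 2 needs tasks 0 and 1.
dependencies = collections.defaultdict(set, {1: {0}, 2: {0, 1}})
depths = compute_depth_map(dependencies)            # {1: 1, 2: 2}
ordered = sorted(range(3), key=lambda ix: depths.get(ix, -1))
print(ordered)                                      # [0, 1, 2]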
Source:client.py  
1"""Module representing the client user of OpenMetroGuide.2Client is capable of calculating the path that needs to be taken from one3station to another. Can use Distance/Cost as the calculation constraints.4"""5import sys6from typing import Optional7import pygame8from pygame.colordict import THECOLORS9from src.Display.Utils.general_utils import WIDTH, in_circle, PALETTE_WIDTH, \10    BLACK, WHITE, draw_text, HEIGHT11from src.Base.map import Map12from src.Base.node import Node13from src.Display.Canvas.user import User14class Client(User):15    """Client is the aspect of the User which displays a Map object on the screen,16    and then uses pygame mouse click event objects to determine the Client's17    starting point and final destination, and the variable of optimization the Client18    prefers. Then, the best possible route between the two stations is highlighted.19    Instance Attributes:20        - metro_map: Refers to the current metro transit map being used by the client to21        locate the stations and find the optimized path as per requirement.22    """23    metro_map: Map24    _start: Optional[Node]25    _end: Optional[Node]26    def __init__(self, input_map: Map, city_name: str) -> None:27        """ Initializes the Instance Attributes of28        the Client class which is a child of User.29        """30        super(Client, self).__init__('distance', city_name)31        self.metro_map = input_map32        for node in self.metro_map.get_all_nodes():33            for neighbor in node.get_neighbours():34                node.update_weights(neighbor)35        self._start = None36        self._end = None37    def handle_mouse_click(self, event: pygame.event.Event,38                           screen_size: tuple[int, int]) -> None:39        """ Handle a mouse click event.40        A pygame mouse click event object has two attributes that are important for this method:41            - event.pos: the (x, y) coordinates of the mouse click42            - event.button: an int representing which mouse button was clicked.43                            1: left-click, 3: right-click44        The screen_size is a tuple of (width, height), and should be used together with45        event.pos to determine which cell is being clicked.46        If the click is within the area of the palette, then check if it is within the option of47        distance or cost and handle accordingly.48        If the click is within the grid, check if the click is left or right.49        If it is the left click, this marks the starting station. If the click is a right click50        this marks the destination station. 
The right click is only possible if a starting station51        has already been selected.52        Preconditions:53            - event.type == pygame.MOUSEBUTTONDOWN54            - screen_size[0] >= ...55            - screen_size[1] >= ...56        """57        pygame.init()58        click_coordinates = self.get_click_pos(event)59        if event.pos[0] > WIDTH:  # The click is on the palette60            for option in self.opt_to_center:61                target = self.opt_to_center[option]62                input_rect = pygame.Rect(target[0], target[1], 35, 35)63                if input_rect.collidepoint(event.pos):64                    self._curr_opt = option65        else:  # The click is on the map.66            station = self.node_exists(click_coordinates)67            if event.button == 1:68                self._start = station69            elif event.button == 3:70                self._end = station71        return72    def _connect_final_route(self, path: list[str]) -> None:73        """Displays the final path highlighting the tracks being used,74        making the others gray.75        """76        lst = [n for n in self.metro_map.get_all_nodes('') if n.name not in path]77        for i in range(0, len(path) - 1):78            node = self.metro_map.get_node(path[i])79            transform_node = self.scale_factor_transformations(node.coordinates)80            for neighbours in node.get_neighbours():81                if neighbours.name == path[i + 1]:82                    color = node.get_color(neighbours)83                    transform_neighbor = self.scale_factor_transformations(neighbours.coordinates)84                    if transform_neighbor[0] <= WIDTH and transform_node[0] <= WIDTH:85                        pygame.draw.line(surface=self._screen,86                                         color=color,87                                         start_pos=transform_node,88                                         end_pos=transform_neighbor,89                                         width=5)90        for node in lst:91            transform_node = self.scale_factor_transformations(node.coordinates)92            for neighbours in node.get_neighbours():93                transform_neighbor = self.scale_factor_transformations(neighbours.coordinates)94                if transform_neighbor[0] <= WIDTH and transform_node[0] <= WIDTH:95                    pygame.draw.line(surface=self._screen,96                                     color=THECOLORS['gray50'],97                                     start_pos=transform_node,98                                     end_pos=transform_neighbor,99                                     width=3)100        return101    def create_palette(self) -> None:102        """ Draw the palette which contains the images103        representing distance and cost for the client104        to choose as per their requirement.105        """106        rect_width = (PALETTE_WIDTH // 4)107        ht = PALETTE_WIDTH * 6108        image1 = pygame.image.load('../Assets/distance.png')109        image_distance = pygame.transform.scale(image1, (30, 30))110        image2 = pygame.image.load('../Assets/cost.png')111        image_cost = pygame.transform.scale(image2, (30, 30))112        self._screen.blit(image_distance, (WIDTH + rect_width, ht))113        self.opt_to_center['distance'] = (WIDTH + rect_width, ht)114        self._screen.blit(image_cost, (WIDTH + rect_width, 2 * ht - 50))115        self.opt_to_center['cost'] = (WIDTH + rect_width, 2 * ht - 50)116    def set_selection(self, 
palette_choice: str) -> None:117        """Darkens the borders of the selected118        optimization from the palette provided.119        """120        target = self.opt_to_center[palette_choice]121        input_rect = pygame.Rect(target[0] - 2, target[1], 35, 35)122        pygame.draw.rect(self._screen, BLACK, input_rect, 3)123    def display(self) -> None:124        """Performs the display of the screen for a Client."""125        while True:126            self._screen.fill(WHITE)127            self.draw_grid()128            self.create_palette()129            self.set_selection(self._curr_opt)130            visited = set()131            for node in self.metro_map.get_all_nodes('station'):132                transform_node = self.scale_factor_transformations(node.coordinates)133                if 0 < transform_node[0] <= 800 and 0 < transform_node[1] < 800:134                    pygame.draw.circle(self._screen, BLACK,135                                       transform_node, 5)136            for event in pygame.event.get():137                if event.type == pygame.QUIT:138                    sys.exit()139                elif event.type == pygame.MOUSEBUTTONDOWN:140                    self.handle_mouse_click(event, (WIDTH, HEIGHT))141                elif event.type == pygame.KEYUP:142                    if event.key == pygame.K_DOWN:143                        self.handle_d_shift()144                    elif event.key == pygame.K_UP:145                        self.handle_u_shift()146                    elif event.key == pygame.K_LEFT:147                        self.handle_l_shift()148                    elif event.key == pygame.K_RIGHT:149                        self.handle_r_shift()150                    elif event.key == pygame.K_p and pygame.key.get_mods() & pygame.KMOD_CTRL:151                        self.handle_zoom_in()152                    elif event.key == pygame.K_m and pygame.key.get_mods() & pygame.KMOD_CTRL:153                        self.handle_zoom_out()154            if self._start is not None and self._end is not None:155                path = self.metro_map.optimized_route(start=self._start.name,156                                                      destination=self._end.name,157                                                      optimization=self._curr_opt)158                self._connect_final_route(path)159            else:160                for node in self.metro_map.get_all_nodes(''):161                    visited.add(node)162                    transform_node = self.scale_factor_transformations(node.coordinates)163                    for u in node.get_neighbours():164                        if u not in visited:165                            transform_u = self.scale_factor_transformations(u.coordinates)166                            if transform_u[0] <= WIDTH and transform_node[0] <= WIDTH:167                                pygame.draw.line(self._screen, node.get_color(u),168                                                 transform_node,169                                                 transform_u, 3)170            self.hover_display()171            pygame.display.update()172    def node_exists(self, coordinates: tuple[float, float]) -> Optional[Node]:173        """Return the node if it exists at given coordinates. 
Else, return None.174        """175        for node in self.metro_map.get_all_nodes():176            if self.scale_factor_transformations(node.coordinates) == coordinates:177                return node178        return None179    def hover_display(self) -> None:180        """Gains the current nodes which can be displayed through181        the self.active_nodes attribute. Provides information on both name and zone."""182        for node in self.metro_map.get_all_nodes('station'):183            transformed = self.scale_factor_transformations(node.coordinates)184            if node == self._start and self._start is not None:185                show = node.name + ' ' + '(' + node.zone + ')' + ' START'186                draw_text(self._screen, show, 17,187                          (transformed[0] + 4, transformed[1] - 15), THECOLORS['green'])188            elif node == self._end and self._start is not None:189                show = node.name + ' ' + '(' + node.zone + ')' + ' END'190                draw_text(self._screen, show, 17,191                          (transformed[0] + 4, transformed[1] - 15), THECOLORS['red'])192            elif in_circle(5, transformed, pygame.mouse.get_pos()):193                show = node.name + ' ' + '(' + node.zone + ')'194                draw_text(self._screen, show, 17,195                          (transformed[0] + 4, transformed[1] - 15))...
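Client.handle_mouse_click() above routes a click either to the palette (x beyond the grid width, tested with Rect.collidepoint) or to the grid, where the left and right buttons mark the start and destination stations. A self-contained sketch of that routing logic follows; WIDTH, OPTION_RECTS and the state dict are illustrative stand-ins rather than the project's real objects.

import pygame

WIDTH = 800  # stand-in for the grid width constant used above
OPTION_RECTS = {
    'distance': pygame.Rect(WIDTH + 20, 120, 35, 35),
    'cost': pygame.Rect(WIDTH + 20, 190, 35, 35),
}

def route_click(pos, button, state):
    """Update state the way handle_mouse_click() updates the Client."""
    if pos[0] > WIDTH:  # click landed on the palette
        for option, rect in OPTION_RECTS.items():
            if rect.collidepoint(pos):
                state['curr_opt'] = option
    elif button == 1:   # left click marks the starting station
        state['start'] = pos
    elif button == 3:   # right click marks the destination station
        state['end'] = pos

state = {'curr_opt': 'distance', 'start': None, 'end': None}
route_click((WIDTH + 25, 130), 1, state)  # selects the 'distance' option
route_click((100, 200), 3, state)         # marks a destination on the grid
print(state)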
