Source code for anomalydetection.backend.stream.kafka

# -*- coding:utf-8 -*-
#
# Anomaly Detection Framework
# Copyright (C) 2018 Bluekiri BigData Team <bigdata@bluekiri.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import os
import warnings
from typing import Generator
from multiprocessing import Queue as MultiprocessingQueue
from queue import Queue as Queue

from kafka import KafkaConsumer, KafkaProducer

from anomalydetection import BASE_PATH
from anomalydetection.backend.entities.input_message import InputMessage
from anomalydetection.backend.stream import BaseStreamAggregation
from anomalydetection.backend.stream import BaseStreamConsumer
from anomalydetection.backend.stream import BaseStreamProducer
from anomalydetection.backend.stream.agg.functions import AggregationFunction
from anomalydetection.common.concurrency import Concurrency
from anomalydetection.common.logging import LoggingMixin


[docs]class KafkaStreamConsumer(BaseStreamConsumer, LoggingMixin): def __init__(self, broker_servers: str, input_topic: str, group_id: str) -> None: """ KafkaStreamConsumer constructor :param broker_servers: broker servers :param input_topic: input topic :param group_id: consumer group id """ super().__init__() self.broker_servers = broker_servers.split(",") self.topic = input_topic self.group_id = group_id self.subscribed = True self._kafka_consumer = KafkaConsumer( self.topic, bootstrap_servers=self.broker_servers, group_id=self.group_id) self._kafka_consumer.subscribe([self.topic])
[docs] def unsubscribe(self): self.subscribed = False self._kafka_consumer.unsubscribe()
[docs] def poll(self) -> Generator: while self.subscribed: self.logger.debug("Polling messages (auto ack). START") try: for msg in self._kafka_consumer: message = msg.value.decode('utf-8') self.logger.debug("Message received: {}".format(message)) yield message except Exception as ex: self.logger.error("Error polling messages.", ex) self.logger.debug("Polling messages. END")
def __str__(self) -> str: return "Kafka topic: brokers: {}, topic: {}".format( self.broker_servers, self.topic)
[docs]class KafkaStreamProducer(BaseStreamProducer, LoggingMixin): def __init__(self, broker_servers: str, output_topic: str) -> None: """ KafkaStreamProducer constructor :param broker_servers: broker servers :param output_topic: topic to write to """ super().__init__() self.broker_servers = broker_servers.split(",") self.output_topic = output_topic self.kafka_producer = KafkaProducer( bootstrap_servers=self.broker_servers, api_version=(0, 10))
[docs] def push(self, message: str) -> None: try: self.logger.debug("Pushing message: {}.".format(message)) self.kafka_producer.send(self.output_topic, bytearray(message, 'utf-8')) except Exception as ex: self.logger.error("Pushing message failed.", ex)
def __str__(self) -> str: return "Kafka topic: brokers: {}, topic: {}".format( self.broker_servers, self.output_topic)
[docs]class SparkKafkaStreamConsumer(BaseStreamConsumer, BaseStreamAggregation, LoggingMixin): def __init__(self, broker_servers: str, input_topic: str, group_id: str, agg_function: AggregationFunction, agg_window_millis: int, spark_opts: dict={}, multiprocessing=True) -> None: """ SparkKafkaStreamConsumer constructor :param broker_servers: broker servers :param input_topic: input topic :param group_id: consumer group id :param agg_function: aggregation function to apply :param agg_window_millis: aggregation window in milliseconds :param spark_opts: spark options dict :param multiprocessing: use multiprocessing instead of threading """ super().__init__(agg_function, agg_window_millis) self.broker_servers = broker_servers.split(",") self.input_topic = input_topic self.group_id = group_id self.spark_opts = spark_opts self.subscribed = True self.multiprocessing = multiprocessing if self.multiprocessing: self.queue = MultiprocessingQueue() else: self.queue = Queue() def run_spark_job(queue: Queue, _agg_function: AggregationFunction, _agg_window_millis: int, _spark_opts: dict = {}, _environment: dict = {}): os.environ.update(_environment) try: try: import findspark findspark.init() except Exception as ex: self.logger.warn("Cannot import Spark pyspark with" " findspark. Message: {}".format(str(ex))) pass from pyspark.sql import SparkSession from pyspark.streaming import StreamingContext from pyspark.streaming.kafka import KafkaUtils from pyspark.sql.functions import expr, window spark_builder = SparkSession \ .builder \ for k in _spark_opts: spark_builder = spark_builder.config(k, _spark_opts[k]) spark_builder = spark_builder \ .appName(str(self)) \ .config("spark.jars.packages", "org.apache.spark:spark-streaming-kafka-0-8_2.11:2.2.1," "org.apache.bahir:spark-streaming-pubsub_2.11:2.2.1") \ .config("spark.jars", BASE_PATH + "/lib/streaming-pubsub-serializer_2.11-0.1.jar") spark = spark_builder.getOrCreate() spark.sparkContext.setLogLevel("WARN") ssc = StreamingContext(spark.sparkContext, (agg_window_millis / 1000)) agg = expr("value") if _agg_function == AggregationFunction.AVG: agg = expr("avg(value)") elif _agg_function == AggregationFunction.SUM: agg = expr("sum(value)") elif _agg_function == AggregationFunction.COUNT: agg = expr("count(value)") elif _agg_function == AggregationFunction.P50: agg = expr("percentile(value, 0.5)") elif _agg_function == AggregationFunction.P75: agg = expr("percentile(value, 0.75)") elif _agg_function == AggregationFunction.P95: agg = expr("percentile(value, 0.95)") elif _agg_function == AggregationFunction.P99: agg = expr("percentile(value, 0.99)") kafka_stream = KafkaUtils.createDirectStream( ssc, [self.input_topic], {"metadata.broker.list": ",".join(self.broker_servers)}) def aggregate_rdd(_queue, _agg, df, ts): secs = int(self.agg_window_millis / 1000) win = window("ts", "{} seconds".format(secs)) if df.first(): aggs = df \ .groupBy("application", win) \ .agg(_agg.alias("value")) \ .collect() for row in aggs: message = InputMessage(row["application"], value=row["value"], ts=ts) self.logger.debug( "Enqueue: {}".format(message.to_json())) try: _queue.put(message.to_json()) except AssertionError as ex: self.logger.warn(str(ex)) else: warnings.warn("Empty RDD") # Create kafka stream kafka_stream \ .map(lambda x: x[1]) \ .foreachRDD(lambda ts, rdd: aggregate_rdd(queue, agg, spark.read.json(rdd), ts)) # Run ssc.start() if "timeout" in _spark_opts: ssc.awaitTerminationOrTimeout(_spark_opts["timeout"]) ssc.stop() spark.stop() else: ssc.awaitTermination() ssc.stop() spark.stop() except Exception as e: raise e # Run in multiprocessing, each aggregation runs a spark driver. runner = Concurrency.run_process \ if self.multiprocessing \ else Concurrency.run_thread Concurrency.get_lock("spark").acquire() pid = runner(target=run_spark_job, args=(self.queue, self.agg_function, self.agg_window_millis, self.spark_opts, os.environ.copy()), name="PySpark {}".format(str(self))) Concurrency.schedule_release("spark", 30) self.pid = pid
[docs] def unsubscribe(self): self.subscribed = False if isinstance(self.queue, MultiprocessingQueue): self.queue.close() self.queue.join_thread() elif isinstance(self.queue, Queue): self.queue.join()
[docs] def poll(self) -> Generator: while self.subscribed: try: message = self.queue.get(timeout=2) yield message except Exception as _: pass
def __str__(self) -> str: return "Kafka aggregated topic: " \ "brokers: {}, topic: {}, func: {}, window: {}ms".format( self.broker_servers, self.input_topic, self.agg_function.name, self.agg_window_millis)