K-Means Clustering and the Security Domain
Clustering
Clustering is an unsupervised learning technique for discovering groups of similar things, ideas, or people. Unlike supervised algorithms, we don't train clustering algorithms on examples with known labels. Instead, clustering tries to find structure within a training set in which no data point serves as a label.
K-Means Clustering
Suppose our goal is to find a few groups of similar points in a dataset.
K-Means begins with k randomly placed centroids. Centroids, as their name suggests, are the center points of the clusters. For example, we might start by adding four random centroids.
Then we assign each existing data point to its nearest centroid.
After the assignment, we move each centroid to the average location of the points assigned to it. Remember, centroids are supposed to be the center points of the clusters.
An iteration concludes once we have finished relocating the centroids. We repeat these iterations until the assignments stop changing between consecutive iterations.
When the algorithm terminates, the four clusters are found as expected. Now that we know how K-Means works, let's put it to work on a cybersecurity problem.
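Before we turn to the security use case, here is a minimal sketch of those steps in plain Python with NumPy (the choice of k=4, the iteration cap, and the random data in the usage comment are illustrative assumptions, not part of the router dataset used later):

import numpy as np

def kmeans(points, k=4, iterations=100, seed=0):
    rng = np.random.default_rng(seed)
    # Start with k randomly chosen points as the initial centroids.
    centroids = points[rng.choice(len(points), k, replace=False)]
    for _ in range(iterations):
        # Assign each point to its nearest centroid.
        distances = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Move each centroid to the average location of the points assigned to it.
        # (A production version would also handle the case of an empty cluster.)
        new_centroids = np.array([points[labels == j].mean(axis=0) for j in range(k)])
        if np.allclose(new_centroids, centroids):
            break  # assignments have stopped changing, so we are done
        centroids = new_centroids
    return labels, centroids

# Example usage on random 2-D data:
# labels, centers = kmeans(np.random.default_rng(1).normal(size=(200, 2)))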
Why use machine learning in cybersecurity
It's almost impossible for an analyst looking at a time series chart of network traffic to draw any conclusion from what they are looking at. Why? People can't visualize more than three dimensions. And too many false alerts cause analysts to simply ignore some of what they're seeing: there's too much noise.
But machine learning makes it possible to flush out, for example, criminal hackers who are stealing data from your system and transmitting it to their command and control center. This is what intrusion detection systems are supposed to do, but hackers use all kinds of techniques to avoid detection by traditional cybersecurity systems. For example, they could split the stolen data into small pieces and send each piece to a different IP address, such as hijacked home computers, and then use those hijacked machines to forward the pieces to their command and control center.
Machine learning helps us distill dozens or hundreds of data points into one or two metrics. Then, we can build our charts and alerts around those. Now, those alerts are significantly more valuable.
In this example we illustrate one approach to looking at network traffic. We use router logs provided by Verizon, produced by a Bro-type router. We group each record into one of seven clusters, then look at the traffic in the clusters with the smallest number of entries. Those, by definition, are our outliers.
The code, explained
The code is available here, and the data here. This is data from a network analysis tool called Zeek, formerly called Bro.
The University of Cincinnati provides this description of the columns in this data:
- ts—time; timestamp
- uid—string; unique ID of connection
- orig_h—addr; originating endpoint’s IP address (aka ORIG)
- orig_p—port; originating endpoint’s TCP/UDP port or ICMP code
- resp_h—addr; responding endpoint’s IP address (aka RESP)
- resp_p—port; responding endpoint’s TCP/UDP port or ICMP code
- proto—transport_proto; transport layer protocol of connection
- service—string; dynamically detected application protocol, if any
- duration—interval; time from first packet seen to last packet seen
- orig_bytes—count; originator payload bytes, from sequence numbers if TCP
- resp_bytes—count; responder payload bytes, from sequence numbers if TCP
- conn_state—string; connection state (see conn.log:conn_state table)
- local_orig—bool; if conn originated locally T; if remotely F. If Site::local_nets empty, always unset
- missed_bytes—count; number of missing bytes in content gaps
- history—string; connection state history (see conn.log:history table)
- orig_pkts—count; number of ORIG packets
- orig_ip_bytes—count; number of ORIG IP bytes (via IP total_length header field)
- resp_pkts—count; number of RESP packets
- resp_ip_bytes—count; number of RESP IP bytes (via IP total_length header field)
- tunnel_parents—set; if tunneled, connection UID of encapsulating parent(s)
- orig_cc—string; ORIG GeoIP country code
- resp_cc—string; RESP GeoIP country code
First, we load the log file into a Spark dataframe, using the CSV reader with a tab separator.
from pyspark.sql.types import StructType, StructField, FloatType, BooleanType
from pyspark.sql.types import DoubleType, IntegerType, StringType
import pyspark
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import lit
from pyspark.sql.functions import udf, concat
from pyspark import SQLContext

conf = pyspark.SparkConf()
sc = pyspark.SparkContext.getOrCreate(conf=conf)
sqlcontext = SQLContext(sc)

schema = StructType([
    StructField("ts", StringType(), True),
    StructField("uid", StringType(), True),
    StructField("origh", StringType(), True),
    StructField("origp", StringType(), True),
    StructField("resph", StringType(), True),
    StructField("respp", StringType(), True),
    StructField("proto", StringType(), True),
    StructField("service", StringType(), True),
    StructField("duration", FloatType(), True),
    StructField("origbytes", StringType(), True),
    StructField("respbytes", StringType(), True),
    StructField("connstate", StringType(), True),
    StructField("localorig", StringType(), True),
    StructField("missedbytes", StringType(), True),
    StructField("history", StringType(), True),
    StructField("origpkts", IntegerType(), True),
    StructField("origipbytes", IntegerType(), True),
    StructField("resppkts", IntegerType(), True),
    StructField("respipbytes", IntegerType(), True),
    StructField("tunnelparents", StringType(), True)
])

df = sqlcontext.read.csv(path="/home/ubuntu/Documents/forensics/bigger.log",
                         sep="\t", schema=schema)
df2 = df.fillna(0)
Next, we register a UDF (user-defined function). We will use it to turn every field passed to it into an integer, because machine learning, for the most part, works only with numbers.
colsInt = udf(lambda z: toInt(z), IntegerType())
sqlcontext.udf.register("colsInt", colsInt)

def toInt(s):
    if not s:
        return 0
    if isinstance(s, str):
        st = [str(ord(i)) for i in s]
        return int(''.join(st))
    else:
        return s
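As a quick illustration of what this encoding does (the values below are hypothetical examples, not rows from the log): toInt() concatenates the decimal character codes of a string into a single integer, maps empty or missing values to 0, and passes non-strings through unchanged.

# Hypothetical examples of the encoding produced by toInt():
print(toInt("tcp"))   # 11699112  ('t'=116, 'c'=99, 'p'=112 concatenated)
print(toInt("80"))    # 5648      ('8'=56, '0'=48 concatenated)
print(toInt(None))    # 0         (empty or missing values map to 0)
print(toInt(443))     # 443       (non-strings pass through unchanged)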
Now we create additional columns, which are the columns we have selected to feed into our model. For each of these, we call the colsInt() UDF to convert the value to a number.
You could vary the choice of columns according to what hypotheses you want to follow. For example, below we look at the ports and traffic as well as the protocol.
- There might be other metrics in that log that we could add or remove.
- We should probably leave the destination IP address out of the model because of the hacker’s ability to hide their true destination.
- We might drop UDP records, since sftp (which runs over TCP) is the protocol a hacker would most likely use to transmit stolen data.
- Or, we could include the time of day in the local time zone to isolate after-hours events (see the sketch after this list).
It all depends on what kind of activity you want to focus on.
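For example, here is a minimal sketch of the last idea, assuming the ts column holds Unix epoch seconds (as in standard Zeek conn logs) and that an 08:00-18:00 working day is a reasonable cutoff; both the column handling and the hour boundaries are assumptions, not part of the original analysis:

from pyspark.sql.functions import col, from_unixtime, hour

# Derive the hour of day (in the Spark session's time zone) from the Zeek timestamp.
with_hour = df2.withColumn("hour", hour(from_unixtime(col("ts").cast("double"))))

# Keep only traffic outside a nominal 08:00-18:00 working day.
after_hours = with_hour.filter((col("hour") < 8) | (col("hour") > 18))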
Note that each of the .withColumn() statements creates a new dataframe. This is because Spark dataframes are immutable.
a = df2.withColumn('iorigp', colsInt('origp'))
c = a.withColumn('irespp', colsInt('respp'))
d = c.withColumn('iproto', colsInt('proto'))
e = d.withColumn('iorigh', colsInt('origh'))
f = e.withColumn('iorigbytes', colsInt('origbytes'))
g = f.withColumn('irespbytes', colsInt('respbytes'))
h = g.withColumn('iorigpkts', colsInt('origpkts'))
i = h.withColumn('iorigipbytes', colsInt('origipbytes'))

columns = ['iorigp', 'irespp', 'iproto', 'iorigbytes', 'irespbytes', 'iorigpkts', 'iorigipbytes']
The next step adds a column to our dataframe called features. This is a vector assembled from the columns we have selected. The K-Means algorithm expects there to be a features column.
vecAssembler = VectorAssembler(inputCols=columns, outputCol="features")
router = vecAssembler.transform(i)
Here, we use the K-Means algorithm. One nice thing about Apache Spark is that its machine learning algorithms are easy to use. They don't require the preprocessing and reshaping that other frameworks do, and they work with Spark dataframes, so we can work with much larger sets of data. (Pandas does not scale the way Spark dataframes do.)
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator

kmeans = KMeans().setK(7).setSeed(1)
model = kmeans.fit(router)

predictions = model.transform(router)
p = predictions.groupby('prediction').count()
q = p.toPandas()
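As a side note, the ClusteringEvaluator imported above is not actually used in this analysis, but you could use it to sanity-check the choice of k=7 by comparing silhouette scores for a few values of k; this is just a sketch, and the range of k values is arbitrary:

# Optional: compare silhouette scores (squared Euclidean by default) for several k.
evaluator = ClusteringEvaluator()
for k in range(2, 10):
    candidate = KMeans().setK(k).setSeed(1).fit(router)
    print(k, evaluator.evaluate(candidate.transform(router)))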
We have grouped the observations into 7 clusters. Cluster 0 has 40,303 router records, but cluster 2 has only 171. Clearly, those are outliers, so this is where we focus our cybersecurity analysis.
We can plot that as a bar chart to further show how the data is clustered.
from plotly.offline import plot
import pandas as pd
import plotly.graph_objects as go

fig = go.Figure(
    data=[go.Bar(x=q['prediction'], y=q['count'])],
    layout_title_text="K Means Count"
)
fig.show()
Plotly uses JavaScript to create popups that give you more information wherever you place the cursor. We've placed it at the point (prediction=2, count=171).
So, let's make a new dataframe of just those records in cluster 2. (In the Pandas dataframe q of cluster counts, cluster 2 happens to sit at row index 5, so don't confuse the cluster number with the row index.)
suspect = predictions.filter("prediction == 2")
Here we convert the output to Pandas, simply because the Jupyter notebook displays Pandas data more clearly than it does Spark dataframes, which tend to get wide columns chopped off, making them hard to read.
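The original post shows the resulting table rather than the conversion itself; a minimal sketch of that step might look like the following, where the particular columns selected are an assumption chosen to highlight the IP addresses, ports, and byte counts:

# Pull the suspect cluster into Pandas for easier inspection in Jupyter.
suspect_pd = suspect.select('ts', 'origh', 'origp', 'resph', 'respp',
                            'proto', 'origbytes', 'respbytes').toPandas()
suspect_pd.head(20)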
You can see the same IP address appear more than a few times, which is probably a good place for further analysis. Look and see which machine it is and to whom it connects.
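A quick way to surface those repeated addresses, sketched under the assumption that suspect_pd is the Pandas dataframe from the previous step, is to count how often each originating IP appears and which responders it talks to:

# Most frequent originating IPs in the suspect cluster.
print(suspect_pd['origh'].value_counts().head(10))

# Most frequent originator/responder pairs in the suspect cluster.
print(suspect_pd.groupby(['origh', 'resph']).size().sort_values(ascending=False).head(10))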