In [1]:
import glob
import os
import sys

# Root of the local Spark install. Honour an existing SPARK_HOME so the notebook
# can run on other machines; fall back to the author's original path otherwise.
spark_path = os.environ.get("SPARK_HOME", "/Users/flavio.clesio/Documents/spark-2.1.0")

os.environ['SPARK_HOME'] = spark_path
# NOTE(review): HADOOP_HOME normally points at a Hadoop install rather than at
# Spark; kept as in the original setup — confirm it is actually needed.
os.environ['HADOOP_HOME'] = spark_path

# The py4j zip name embeds its version, which must match the Spark release, so
# locate it instead of hard-coding it. A Spark distribution normally ships a
# single py4j zip; if none is found (e.g. the directory does not exist), fall
# back to the name bundled with Spark 2.1.0.
_py4j_zips = sorted(glob.glob(os.path.join(spark_path, "python", "lib", "py4j-*-src.zip")))
_py4j_zip = _py4j_zips[-1] if _py4j_zips else spark_path + "/python/lib/py4j-0.10.4-src.zip"

# Make the PySpark sources importable (same entries, same order as before).
for _subdir in ("/bin", "/python", "/python/pyspark/", "/python/lib", "/python/lib/pyspark.zip"):
    sys.path.append(spark_path + _subdir)
sys.path.append(_py4j_zip)

In [2]:
from pyspark import SparkContext
from pyspark import SparkConf

In [3]:
# Spark configuration for this example, built one setter at a time
# (each SparkConf setter returns the SparkConf, which is what allows chaining).
conf = SparkConf()
conf = conf.setMaster("local")
conf = conf.setAppName("Survival Regression Example")
# NOTE(review): with a "local" master the executor runs inside the driver JVM,
# so this setting likely has no effect — confirm; driver memory is what counts there.
conf = conf.set("spark.executor.memory", "1g")

In [4]:
# Reuse an already-running SparkContext instead of raising
# "Cannot run multiple SparkContexts at once" when this cell is re-run.
# Note: if a context already exists, `conf` is not re-applied to it.
sc = SparkContext.getOrCreate(conf=conf)

In [5]:
sc  # display the SparkContext repr to confirm the context was created


Out[5]:
<pyspark.context.SparkContext at 0x106315690>

In [9]:
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.ml.linalg import Vectors
from pyspark.sql import SQLContext

In [11]:
# SQL entry point bound to the SparkContext above; used below to build DataFrames.
sqlContext = SQLContext(sparkContext=sc)

In [12]:
# Toy survival dataset: one row per subject as (label, censor, features).
# NOTE(review): per Spark's AFTSurvivalRegression docs, `label` is the survival
# time and censor 1.0 marks an observed event while 0.0 marks a censored one —
# confirm for your Spark version.
training_columns = ["label", "censor", "features"]
training_rows = [
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226)),
]
training = sqlContext.createDataFrame(training_rows, training_columns)

In [13]:
# Probability levels at which the fitted model reports predicted quantiles.
quantileProbabilities = [
    0.3,
    0.6,
]

In [15]:
# Accelerated-failure-time survival regressor; predicted quantiles at the
# levels above are written to a "quantiles" output column.
aft = AFTSurvivalRegression(
    quantilesCol="quantiles",
    quantileProbabilities=quantileProbabilities,
)

In [16]:
# Fit the AFT estimator on the toy dataset, producing the fitted model.
model = aft.fit(dataset=training)

In [17]:
print("Coefficients: %s" % (model.coefficients,))  # fitted regression coefficients


Coefficients: [-0.496311146665,0.198444376999]

In [18]:
print("Intercept: %s" % (model.intercept,))  # fitted intercept term


Intercept: 2.6380946151

In [19]:
print("Scale: %s" % (model.scale,))  # fitted scale parameter of the AFT model


Scale: 1.54723455744

In [20]:
# Score the training rows; truncate=False keeps the full quantile vectors visible.
predictions = model.transform(training)
predictions.show(truncate=False)


+-----+------+--------------+------------------+--------------------------------------+
|label|censor|features      |prediction        |quantiles                             |
+-----+------+--------------+------------------+--------------------------------------+
|1.218|1.0   |[1.56,-0.605] |5.718979487635007 |[1.1603238947151664,4.99545601027477] |
|2.949|0.0   |[0.346,2.158] |18.07652118149533 |[3.667545845471739,15.789611866277625]|
|3.627|0.0   |[1.38,0.231]  |7.381861804239096 |[1.4977061305190829,6.44796261233896] |
|0.273|1.0   |[0.52,1.151]  |13.577612501425284|[2.7547621481506854,11.8598722240697] |
|4.199|0.0   |[0.795,-0.226]|9.013097744073898 |[1.8286676321297826,7.87282650587843] |
+-----+------+--------------+------------------+--------------------------------------+


In [ ]:


In [ ]:


In [ ]: