Spark MLlib Linear Regression: Code Implementation and Results

2018/3/8 18:55:01 Category: Big Data

Code implementation:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
/**
  * Created by zhen on 2018/3/10.
  */
object LinearRegressionExample {

  def main(args: Array[String]): Unit = {
    // Set up the environment
    val spark = SparkSession.builder().appName("SendBroadcast").master("local[2]").getOrCreate()
    val sc = spark.sparkContext
    val sqlContext = spark.sqlContext
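    // Prepare the training set

    val raw_data = sc.textFile("src/sparkMLlib/man.txt")
    val map_data = raw_data.map { x =>
      // Pad each comma with a space so that split() does not drop trailing empty fields,
      // then drop the line's final character (presumably a trailing comma)
      val mid = x.replaceAll(",", " ,")
      val split_list = mid.substring(0, mid.length - 1).split(",")
      // Trim every field and treat blank fields as 0.0
      for (i <- 0 until split_list.length) {
        if (split_list(i).trim.equals("")) split_list(i) = "0.0" else split_list(i) = split_list(i).trim
      }
      // Keep fields 1 to 11 as doubles; field 0 is not used
      ( split_list(1).toDouble,split_list(2).toDouble,split_list(3).toDouble,split_list(4).toDouble,
        split_list(5).toDouble,split_list(6).toDouble,split_list(7).toDouble,split_list(8).toDouble,
        split_list(9).toDouble,split_list(10).toDouble,split_list(11).toDouble)
    }
    val mid = map_data.sample(false, 0.6, 0) // Random 60% sample without replacement (seed 0), used to train the model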
    val df = sqlContext.createDataFrame(mid)
    val colArray = Array("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11")
    val data = df.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11")
    val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")
    val vecDF = assembler.transform(data)
    // Prepare the prediction set (the full data set, including the sampled training rows)
    val map_data_for_predict = map_data
    val df_for_predict = sqlContext.createDataFrame(map_data_for_predict)
    val data_for_predict = df_for_predict.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11")
    val colArray_for_predict = Array("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11")
    val assembler_for_predict = new VectorAssembler().setInputCols(colArray_for_predict).setOutputCol("features")
    val vecDF_for_predict: DataFrame = assembler_for_predict.transform(data_for_predict)
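    // Build the model and make predictions
    // Configure the linear regression; note that the label c5 is also one of the assembled feature columns
    val lr1 = new LinearRegression()
    val lr2 = lr1.setFeaturesCol("features").setLabelCol("c5").setFitIntercept(true)
    // RegParam: regularization strength; ElasticNetParam mixes the penalties (0 = L2 only, 1 = L1 only)
    val lr3 = lr2.setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
    // Fit the model on the training set
    val lrModel = lr3.fit(vecDF)
    // Print all model parameters
    println(lrModel.extractParamMap())
    // coefficients: the fitted weights; intercept: the fitted intercept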
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
    // Evaluate the model on the training data
    val trainingSummary = lrModel.summary
    trainingSummary.residuals.show()
    println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") // root mean squared error
    println(s"R2: ${trainingSummary.r2}") // coefficient of determination (goodness of fit); the closer to 1, the better
    val predictions = lrModel.transform(vecDF_for_predict)
    val predict_result = predictions.selectExpr("features","c5", "round(prediction,1) as prediction")
    predict_result.rdd.saveAsTextFile("src/sparkMLlib/manResult")
    sc.stop()
  }
}
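
The line-cleaning step in the listing pads every comma with a space before splitting. The likely reason (an assumption; the post does not say) is that String.split(",") silently discards trailing empty strings, so a row whose last columns are empty would come back too short. Below is a minimal plain-Scala sketch of the same idea that uses split(",", -1) instead of padding. The helper parseLine and the sample row are hypothetical and not taken from man.txt, and unlike the listing the sketch neither drops the final character nor skips column 0.

```scala
object ParseLineSketch {
  // Hypothetical helper: split a comma-separated line, keep trailing
  // empty fields, and map blank fields to 0.0 as the listing does.
  def parseLine(line: String): Array[Double] =
    line.split(",", -1)                          // limit -1 keeps trailing empty strings
      .map(_.trim)
      .map(f => if (f.isEmpty) 0.0 else f.toDouble)

  def main(args: Array[String]): Unit = {
    // Made-up row whose last two columns are empty.
    val row = "3.1,2.5,,"
    println(row.split(",").length)               // default split drops the trailing empties
    println(parseLine(row).mkString(", "))       // all four columns survive
  }
}
```

With the default split the made-up row yields only two fields, while parseLine returns four values with the blanks mapped to 0.0.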

Performance evaluation:

RMSE: 0.2968176690349843
R2: 0.9715059814474793
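
For reference, both metrics can be computed by hand from labels and predictions. The sketch below uses the textbook definitions, RMSE = sqrt(mean((y - p)^2)) and R2 = 1 - SS_res / SS_tot, which is my understanding of what the training summary reports when an intercept is fitted. The object name and the sample numbers are illustrative only and are not taken from the data set.

```scala
object MetricsSketch {
  // Root mean squared error between labels and predictions.
  def rmse(labels: Seq[Double], preds: Seq[Double]): Double = {
    require(labels.nonEmpty && labels.length == preds.length, "inputs must be non-empty and aligned")
    val sse = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum
    math.sqrt(sse / labels.length)
  }

  // Coefficient of determination: 1 - SS_res / SS_tot.
  def r2(labels: Seq[Double], preds: Seq[Double]): Double = {
    require(labels.nonEmpty && labels.length == preds.length, "inputs must be non-empty and aligned")
    val mean = labels.sum / labels.length
    val ssRes = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum
    val ssTot = labels.map(y => (y - mean) * (y - mean)).sum
    1.0 - ssRes / ssTot
  }

  def main(args: Array[String]): Unit = {
    // Illustrative values only.
    val labels = Seq(1.0, 2.0, 3.0)
    val preds = Seq(1.0, 2.0, 4.0)
    println(s"RMSE: ${rmse(labels, preds)}")
    println(s"R2: ${r2(labels, preds)}")
  }
}
```

A perfect fit gives RMSE 0 and R2 1; predicting the label mean for every row gives R2 0.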

Results (each row is [features, c5, prediction rounded to one decimal place]):

[[4.61,1.51,5.91,4.18,3.91,0.0,7.83,0.0,4.81,4.71,3.44,0.0,3.61,3.76],1.51,1.7]
[[3.1,3.64,1.6,2.57,3.16,0.0,5.6,0.0,1.84,2.77,0.0,2.4,0.0,2.53],3.64,3.4]
[[3.15,4.24,2.89,1.94,3.81,0.0,6.12,0.0,0.0,0.0,2.23,0.0,2.51,3.98],4.24,3.9]
[[2.13,3.81,3.5,3.29,3.47,0.0,0.0,0.0,2.16,2.06,1.65,0.0,3.37,3.93],3.81,3.6]
[[3.6,4.36,2.89,3.46,3.66,0.0,7.17,0.0,2.86,2.58,0.0,2.73,2.73,3.94],4.36,4.0]
[[2.65,3.58,3.9,3.63,2.71,0.0,5.91,0.0,3.63,3.08,2.33,0.0,1.79,2.54],3.58,3.4]
[(14,[0,1,2,3,4,6,8,9],[2.13,2.7,2.26,1.78,2.82,7.15,2.69,2.46]),2.7,2.6]
[[2.31,2.42,4.0,3.27,3.69,0.0,5.87,0.0,0.0,0.0,1.32,0.0,1.32,2.09],2.42,2.4]
[(14,[0,1,2,3,6,10,12,13],[3.4,4.12,3.04,2.76,9.55,1.44,3.61,3.95]),4.12,3.8]
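
The two rows that start with (14,[...],[...]) are not malformed. VectorAssembler emits whichever vector representation is more compact, and a sparse vector prints as (size,[indices],[values]), with every position not listed equal to 0.0. The plain-Scala sketch below expands that notation into the dense form used by the other rows. The object and function names are hypothetical; inside Spark, calling toDense on the vector should give the same result.

```scala
object SparseVectorSketch {
  // Expand (size, indices, values) into a dense array;
  // positions that are not listed stay at 0.0.
  def toDense(size: Int, indices: Array[Int], values: Array[Double]): Array[Double] = {
    require(indices.length == values.length, "indices and values must have the same length")
    val dense = Array.fill(size)(0.0)
    indices.zip(values).foreach { case (i, v) => dense(i) = v }
    dense
  }

  def main(args: Array[String]): Unit = {
    // The first sparse row from the results above.
    val dense = toDense(
      14,
      Array(0, 1, 2, 3, 4, 6, 8, 9),
      Array(2.13, 2.7, 2.26, 1.78, 2.82, 7.15, 2.69, 2.46))
    println(dense.mkString("[", ",", "]"))
  }
}
```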
