Description

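Gaussian Mixture is a probabilistic clustering algorithm: instead of a single hard label, each sample receives a probability of belonging to each cluster.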

Gaussian Mixture clustering performs expectation maximization (EM) for multivariate Gaussian Mixture Models (GMMs). A GMM represents a composite distribution built from k independent Gaussian components, each with an associated “mixing” weight that specifies its contribution to the composite.
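For reference, the composite distribution described above can be written out explicitly. The notation here (n samples x_i, weights w_j, means \mu_j, covariances \Sigma_j) is assumed for illustration and is not taken from the component's documentation:

```latex
% Mixture density: a weighted sum of k Gaussian components.
p(x) = \sum_{j=1}^{k} w_j \, \mathcal{N}(x \mid \mu_j, \Sigma_j),
\qquad w_j \ge 0, \quad \sum_{j=1}^{k} w_j = 1

% Log-likelihood of n sample points under the mixture.
\ell(w, \mu, \Sigma) = \sum_{i=1}^{n} \log \sum_{j=1}^{k} w_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)
```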

Given a set of sample points, this class maximizes the log-likelihood for a mixture of k Gaussians, iterating until the log-likelihood changes by less than tol or until the maximum number of iterations (maxIter) is reached. While this process is generally guaranteed to converge, it is not guaranteed to find a global optimum; it may stop at a local one.
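The iteration described above can be sketched in plain NumPy. This is a minimal illustration of EM for a GMM, not the implementation used by this component. The function names, the deterministic initialization and the small covariance regularization term are all assumptions chosen to keep the example short and reproducible.

```python
import numpy as np

def gaussian_log_pdf(X, mean, cov):
    """Log-density of each row of X under a multivariate Gaussian."""
    d = X.shape[1]
    chol = np.linalg.cholesky(cov)            # cov = chol @ chol.T
    z = np.linalg.solve(chol, (X - mean).T)   # whitened residuals, shape (d, n)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + np.sum(z ** 2, axis=0))

def fit_gmm(X, k=2, tol=0.01, max_iter=100, reg=1e-6):
    """EM for a k-component GMM (hypothetical helper, for illustration only)."""
    n, d = X.shape
    # Assumed initialization: means spread along the first coordinate,
    # spherical covariances, uniform weights.
    order = np.argsort(X[:, 0])
    means = X[order[np.linspace(0, n - 1, k).astype(int)]].copy()
    covs = np.stack([np.eye(d) * X.var(axis=0).mean()] * k)
    weights = np.full(k, 1.0 / k)
    prev_ll = -np.inf
    for _ in range(max_iter):
        # E-step: responsibilities, computed in log space for stability.
        log_p = np.stack(
            [np.log(weights[j]) + gaussian_log_pdf(X, means[j], covs[j])
             for j in range(k)], axis=1)
        log_norm = np.logaddexp.reduce(log_p, axis=1)
        resp = np.exp(log_p - log_norm[:, None])
        ll = log_norm.sum()
        # Stop once the log-likelihood changes by less than tol.
        if abs(ll - prev_ll) < tol:
            break
        prev_ll = ll
        # M-step: re-estimate weights, means and covariances.
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp.T @ X) / nk[:, None]
        for j in range(k):
            diff = X - means[j]
            covs[j] = (resp[:, j, None] * diff).T @ diff / nk[j] + reg * np.eye(d)
    return weights, means, covs, resp

# The same 15 sample points used in the script example of this page.
X = np.array([
    [-0.6264538, 0.1836433], [-0.8356286, 1.5952808], [0.3295078, -0.8204684],
    [0.4874291, 0.7383247], [0.5757814, -0.3053884], [1.5117812, 0.3898432],
    [-0.6212406, -2.2146999], [11.1249309, 9.9550664], [9.9838097, 10.9438362],
    [10.8212212, 10.5939013], [10.9189774, 10.7821363], [10.0745650, 8.0106483],
    [10.6198257, 9.9438713], [9.8442045, 8.5292476], [9.5218499, 10.4179416],
])
weights, means, covs, resp = fit_gmm(X, k=2)
labels = resp.argmax(axis=1)   # hard assignment = most probable component
print(labels)
```

Because EM only finds a local optimum, a different initialization can lead to a different solution on less well-separated data.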

Parameters

| Name | Description | Type | Required? | Default Value |
| --- | --- | --- | --- | --- |
| tol | Convergence tolerance on the change in log-likelihood between iterations. | Double | | 0.01 |
| vectorCol | Name of the vector column. | String | | |
| k | Number of clusters. | Integer | | 2 |
| maxIter | Maximum number of iterations. | Integer | | 100 |
| predictionCol | Name of the prediction column. | String | | |
| predictionDetailCol | Name of the column that holds the detailed prediction result. | String | | |
| reservedCols | Names of the columns to be retained in the output table. | String[] | | null |

Script Example

Code

    import numpy as np
    import pandas as pd
    from pyalink.alink import *

    # If no Alink environment is active yet, start one first, e.g. useLocalEnv(1).

    # Each row is one 2-dimensional sample, encoded as a space-separated vector string.
    data = np.array([
        ["-0.6264538 0.1836433"],
        ["-0.8356286 1.5952808"],
        ["0.3295078 -0.8204684"],
        ["0.4874291 0.7383247"],
        ["0.5757814 -0.3053884"],
        ["1.5117812 0.3898432"],
        ["-0.6212406 -2.2146999"],
        ["11.1249309 9.9550664"],
        ["9.9838097 10.9438362"],
        ["10.8212212 10.5939013"],
        ["10.9189774 10.7821363"],
        ["10.0745650 8.0106483"],
        ["10.6198257 9.9438713"],
        ["9.8442045 8.5292476"],
        ["9.5218499 10.4179416"],
    ])
    df_data = pd.DataFrame({
        "features": data[:, 0],
    })
    source = dataframeToOperator(df_data, schemaStr='features string', op_type='batch')

    # tol is set to 0, so the log-likelihood change is never below the tolerance
    # and training runs until maxIter is reached.
    gmm = GaussianMixture() \
        .setPredictionCol("cluster_id") \
        .setVectorCol("features") \
        .setPredictionDetailCol("cluster_detail") \
        .setTol(0.)

    # Train the model, then predict on the same data and print the result.
    gmm.fit(source).transform(source).print()

Results

        features                cluster_id  cluster_detail
    0   -0.6264538 0.1836433    0           1.0 4.275273913994647E-92
    1   -0.8356286 1.5952808    0           1.0 1.0260377730322135E-92
    2   0.3295078 -0.8204684    0           1.0 1.0970173367582936E-80
    3   0.4874291 0.7383247     0           1.0 3.30217313232611E-75
    4   0.5757814 -0.3053884    0           1.0 3.163811360527691E-76
    5   1.5117812 0.3898432     0           1.0 2.1018052308786076E-62
    6   -0.6212406 -2.2146999   0           1.0 6.772270268625197E-97
    7   11.1249309 9.9550664    1           3.1567838012477083E-56 1.0
    8   9.9838097 10.9438362    1           1.9024447346702333E-51 1.0
    9   10.8212212 10.5939013   1           2.8009730987296404E-56 1.0
    10  10.9189774 10.7821363   1           1.7209132744891575E-57 1.0
    11  10.0745650 8.0106483    1           2.864269663513225E-43 1.0
    12  10.6198257 9.9438713    1           5.77327399194046E-53 1.0
    13  9.8442045 8.5292476     1           2.5273123050926845E-43 1.0
    14  9.5218499 10.4179416    1           1.7314580596765865E-46 1.0
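
In the output above, cluster_detail appears to hold one probability per cluster, and cluster_id matches the position of the largest one. The sketch below checks that reading on two rows copied from the results. It assumes the detail value is a space-separated string exactly as printed, and parse_detail is a hypothetical helper, not part of the library.

```python
# Two rows copied from the results table: (features, cluster_id, cluster_detail).
rows = [
    ("-0.6264538 0.1836433", 0, "1.0 4.275273913994647E-92"),
    ("11.1249309 9.9550664", 1, "3.1567838012477083E-56 1.0"),
]

def parse_detail(detail):
    """Split a space-separated probability string into floats (assumed format)."""
    return [float(part) for part in detail.split()]

matches = []
for features, cluster_id, detail in rows:
    probs = parse_detail(detail)
    # Index of the most probable cluster.
    best = max(range(len(probs)), key=probs.__getitem__)
    matches.append(best == cluster_id)
    print(features, "->", best, probs)
```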