From 9873fb7fa0c9c5198af9c269b5cb69dfdd0f7b43 Mon Sep 17 00:00:00 2001
From: Nathalie Vialaneix <nathalie.vialaneix@inrae.fr>
Date: Sat, 11 Mar 2023 15:21:29 +0100
Subject: [PATCH] plots: added a quality plot function

---
 R/plot_functions.R | 64 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/R/plot_functions.R b/R/plot_functions.R
index 35fd158..8e85796 100644
--- a/R/plot_functions.R
+++ b/R/plot_functions.R
@@ -2,6 +2,8 @@
 #' @importFrom reshape2 melt
 #' @import dendextend
 #' @importFrom RColorBrewer brewer.pal
+#' @importFrom stats as.dendrogram
+#' @importFrom graphics abline
 
 plot_dendrogram <- function(x) {
   
@@ -310,6 +312,68 @@ plot_selection <- function(x, sel.type, threshold) {
   return(p)
 }
 
+plot_quality <- function(x, quality.crit) {
+  
+  if ("quality" %in% names(x)) {
+    valid_criteria <- c("mse", "Precision", "Recall", "ARI", "NMI")
+  } else valid_criteria <- "mse"
+  
+  crit_ok <- all(sapply(quality.crit, function(cc) cc %in% valid_criteria))
+  if (!crit_ok || length(quality.crit) > 2) {
+    stop(paste0("'quality.crit' must be a vector with length at most 2 in ",
+                paste(valid_criteria, collapse = ", "), "."))
+  }
+  
+  if (length(quality.crit) == 1) {
+    if (quality.crit == "mse") {
+      df <- data.frame("criterion" = x$mse$mse, "at" = x$mse$clust)
+      ylimits <- c(0, max(df$criterion))
+    } else {
+      df <- data.frame("criterion" = x$quality[, quality.crit], 
+                       "at" = as.numeric(x$quality$clust))
+      if (quality.crit == "ARI") {
+        ylimits <- c(-1, 1)
+      } else ylimits <- c(0, 1)
+    }
+    p <- ggplot(df, aes(x = at, y = criterion)) + 
+      geom_jitter(width=0.2, height = 0) + theme_bw() + 
+      xlab("number of intervals") + ylab(quality.crit) + ylim(ylimits) +
+      scale_x_continuous(breaks = unique(df$at), 
+                         limits = c(min(df$at) - 0.5, max(df$at + 0.5)))
+  } else {
+    if ("mse" %in% quality.crit) {
+      quality.crit <- setdiff(quality.crit, "mse")
+      df <- data.frame("x" = x$quality[, quality.crit], "y" = x$mse$mse, 
+                       "at" = as.factor(x$quality$clust))
+      quality_names <- c(quality.crit, "mse")
+      if (quality.crit == "ARI") {
+        xlimits <- c(-1, 1)
+      } else xlimits <- c(0, 1)
+      ylimits <- c(0, max(df$y))
+    } else {
+      df <- data.frame("x" = x$quality[, quality.crit[1]], 
+                       "y" = x$quality[, quality.crit[2]], 
+                       "at" = as.factor(x$quality$clust))
+      quality_names <- quality.crit
+      if (quality.crit[1] == "ARI") {
+        xlimits <- c(-1, 1)
+      } else xlimits <- c(0, 1)
+      if (quality.crit[2] == "ARI") {
+        ylimits <- c(-1, 1)
+      } else ylimits <- c(0, 1)
+    }
+    p <- ggplot(df, aes(x = x, y = y, colour = at)) + geom_point() + 
+      theme_bw() + xlab(quality_names[1]) + ylab(quality_names[2]) + 
+      xlim(xlimits) + ylim(ylimits) + 
+      scale_colour_discrete(name = "# intervals")
+    if (length(unique(df$at)) > 20) {
+      p <- p + theme(legend.position = "none")
+    }
+  }
+  
+  return(p)
+}
+
 compute_sumimp <- function(importances, func_name, var_names) {
   FUN <- eval(func_name)
   impsummary <- lapply(importances, function(alist) apply(alist, 1, FUN))
-- 
GitLab