#' Plot one or two quality criteria of a result object
#'
#' With a single criterion, its values are drawn (with horizontal jitter)
#' against the number of intervals. With two criteria, they are drawn against
#' each other and the points are coloured by the number of intervals; the
#' legend is hidden when there are more than 20 distinct numbers of intervals.
#'
#' @param x A list-like object. `x$mse$mse` and `x$mse$clust` are read for the
#'   `"mse"` criterion. When `x` has an element named `"quality"`, the columns
#'   `x$quality$clust` and `x$quality[[<criterion>]]` are read for the other
#'   criteria.
#' @param quality.crit Character vector holding one or two distinct criteria.
#'   Accepted values are `"mse"` and, only when `x` has a `"quality"` element,
#'   `"Precision"`, `"Recall"`, `"ARI"` and `"NMI"`.
#'
#' @return The `ggplot` object built for the requested criteria.
#'
#' @noRd
plot_quality <- function(x, quality.crit) {

  if ("quality" %in% names(x)) {
    valid_criteria <- c("mse", "Precision", "Recall", "ARI", "NMI")
  } else {
    valid_criteria <- "mse"
  }
  criteria_list <- paste(valid_criteria, collapse = ", ")

  # An empty selection used to slip through the checks below and fail later
  # with an unrelated "undefined columns" error.
  if (length(quality.crit) == 0) {
    stop(paste0("'quality.crit' must contain at least one criterion in ",
                criteria_list, "."))
  }
  # `%in%` is vectorised and never returns NA, so NA entries are rejected too.
  if (!all(quality.crit %in% valid_criteria) || length(quality.crit) > 2) {
    stop(paste0("'quality.crit' must be a vector with length at most 2 in ",
                criteria_list, "."))
  }
  # c("mse", "mse") would otherwise leave no criterion for the x axis.
  if (anyDuplicated(quality.crit) > 0) {
    stop("'quality.crit' must not contain the same criterion twice.")
  }

  # Axis limits of the criteria read from `x$quality`: "ARI" is displayed on
  # [-1, 1], every other one on [0, 1].
  bounded_limits <- function(crit) {
    if (crit == "ARI") c(-1, 1) else c(0, 1)
  }

  if (length(quality.crit) == 1) {
    if (quality.crit == "mse") {
      df <- data.frame("criterion" = x$mse$mse, "at" = x$mse$clust)
      ylimits <- c(0, max(df$criterion))
    } else {
      df <- data.frame("criterion" = x$quality[[quality.crit]],
                       "at" = as.numeric(x$quality$clust))
      ylimits <- bounded_limits(quality.crit)
    }
    # Jitter is horizontal only (height = 0) so plotted values stay exact; the
    # x limits are widened by 0.5 on each side to leave room for the jitter.
    p <- ggplot(df, aes(x = at, y = criterion)) +
      geom_jitter(width = 0.2, height = 0) + theme_bw() +
      xlab("number of intervals") + ylab(quality.crit) + ylim(ylimits) +
      scale_x_continuous(breaks = unique(df$at),
                         limits = c(min(df$at) - 0.5, max(df$at) + 0.5))
  } else {
    if ("mse" %in% quality.crit) {
      # "mse" always goes on the y axis, whatever its position in the input.
      other_crit <- setdiff(quality.crit, "mse")
      # NOTE(review): assumes the rows of `x$quality` and `x$mse` describe the
      # same observations in the same order -- confirm against the caller.
      df <- data.frame("x" = x$quality[[other_crit]], "y" = x$mse$mse,
                       "at" = as.factor(x$quality$clust))
      quality_names <- c(other_crit, "mse")
      xlimits <- bounded_limits(other_crit)
      ylimits <- c(0, max(df$y))
    } else {
      df <- data.frame("x" = x$quality[[quality.crit[1]]],
                       "y" = x$quality[[quality.crit[2]]],
                       "at" = as.factor(x$quality$clust))
      quality_names <- quality.crit
      xlimits <- bounded_limits(quality.crit[1])
      ylimits <- bounded_limits(quality.crit[2])
    }
    p <- ggplot(df, aes(x = x, y = y, colour = at)) + geom_point() +
      theme_bw() + xlab(quality_names[1]) + ylab(quality_names[2]) +
      xlim(xlimits) + ylim(ylimits) +
      scale_colour_discrete(name = "# intervals")
    # A legend with more than 20 entries is dropped rather than displayed.
    if (length(unique(df$at)) > 20) {
      p <- p + theme(legend.position = "none")
    }
  }

  p
}