diff --git a/NAMESPACE b/NAMESPACE
index f960cb89b54b814fa97d7edd8c77f6442327442e..b33a6ec79d5ef141164da7aa97e9cafd78d24954 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,12 +1,9 @@
 # Generated by roxygen2: do not edit by hand
 
-export(fisher_mix)
 export(treediff)
 importFrom(dplyr,"%>%")
 importFrom(dplyr,group_by)
 importFrom(dplyr,summarise)
 importFrom(limma,squeezeVar)
 importFrom(stats,cophenetic)
-importFrom(stats,na.omit)
 importFrom(stats,pt)
-importFrom(stats,rf)
diff --git a/R/treediff.R b/R/treediff.R
index 172e847142d75cca3f39739f707da9ab01e82095..fdb6803a393169d37db522a6f11f0e1b664da9fb 100644
--- a/R/treediff.R
+++ b/R/treediff.R
@@ -3,29 +3,29 @@
 #' @description Perform the treediff test to compare two sets of trees.
 #'
 #' @details This function compares two sets of trees using a p-value aggregation
-#' method. The p-values are obtained by the treediff method, as described in 
+#' method. The p-values are obtained by the treediff method, as described in
 #' (Neuvial \emph{et al.}, 2023).
 #'
 #' @param trees1 A list of trees corresponding to the first condition (set).
-#' Trees are structured into groups (or clusters) with the same number of 
-#' replicates in each group. Trees are ordered by groups and then by replicates: 
+#' Trees are structured into groups (or clusters) with the same number of
+#' replicates in each group. Trees are ordered by groups and then by replicates:
 #' \{group1+rep1, group1+rep2, ...\}. One test is performed for each group.
 #' @param trees2 A list of trees corresponding to the second condition. Trees
-#' are also structured in groups (or clusters) that are exactly the same than 
-#' for the first condition. The number of replicates in each group can be 
+#' are also structured in groups (or clusters) that are exactly the same than
+#' for the first condition. The number of replicates in each group can be
 #' different from that of \code{trees1}.
-#' @param replicates A numeric vector of length 2 with the number of replicates 
+#' @param replicates A numeric vector of length 2 with the number of replicates
 #' for each condition.
 #'
-#' @return An object of class \code{htest} with the following entries: \itemize{
+#' @return An object of class \code{treeTest} with the following entries: \itemize{
 #'   \item{p.value}{ the p-value for the treediff test.}
-#'   \item{statistic}{ the value of the Student's statistic of each leaf pair of 
+#'   \item{statistic}{ the value of the Student's statistic of each leaf pair of
 #'   the tree test.}
-#'   \item{p.value.indiv}{ the p-value of the Student's test for each leaf pair 
+#'   \item{p.value.indiv}{ the p-value of the Student's test for each leaf pair
 #'   of the tree test.}
-#'   \item{method}{ a character string indicating what type of test was 
+#'   \item{method}{ a character string indicating what type of test was
 #'   performed.}
-#'   \item{data.name}{ a character string giving the names of the tree 
+#'   \item{data.name}{ a character string giving the names of the tree
 #'   conditions.}
 #' }
 #'
@@ -35,9 +35,9 @@
 #' Sylvain Foissac \email{sylvain.foissac@inrae.fr}\cr
 #' Nathalie Vialaneix \email{nathalie.vialaneix@inrae.fr}
 #'
-#' @references Neuvial Pierre, Randriamihamison Nathanaël, Chavent Marie, 
-#' Foissac Sylvain and Vialaneix Nathalie (2023) Testing differences in 
-#' structure between families of trees. \emph{Preprint submitted for 
+#' @references Neuvial Pierre, Randriamihamison Nathanaël, Chavent Marie,
+#' Foissac Sylvain and Vialaneix Nathalie (2023) Testing differences in
+#' structure between families of trees. \emph{Preprint submitted for
 #' publication}.
 #'
 #' @export
@@ -47,31 +47,49 @@
 #' @importFrom dplyr summarise
 #' @importFrom limma squeezeVar
 #' @importFrom stats cophenetic
-#' @importFrom stats na.omit
 #' @importFrom stats pt
 #'
 #' @examples
 #'
-#' base_data <- matrix(rnorm(2000), nrow = 100, ncol = 200)
+#' leaves <- c(100,120,50,80)
 #'
-#' ## generates two sets of trees with 4 clusters
+#' trees <- sapply(leaves, FUN = function(l){
+#'   base_data <- matrix(rnorm(2000), nrow = l, ncol = 200)
+#'
+#' ## generates two sets of trees with 4 clusters with 100, 120, 50 and 80
+#' ## leaves respectively
 #' ## 4 replicates in the first condition and 6 in the second condition
-#' set1 <- replicate(16, sample(1:100, 50, replace = FALSE))
-#' set2 <- replicate(24, sample(101:200, 50, replace = FALSE))
 #'
-#' trees <- apply(cbind(set1, set2), 2, function(asample) {
-#'   samples <- base_data[, asample]
-#'   out <- hclust(dist(samples), method = "ward.D2")
-#' return(out)
+#'   set1 <- replicate(4, sample(1:100, 50, replace = FALSE))
+#'   set2 <- replicate(6, sample(101:200, 50, replace = FALSE))
+#'
+#'   trees1 <- apply(set1, 2, function(asample) {
+#'     samples <- base_data[, asample]
+#'     out <- hclust(dist(samples), method = "ward.D2")
+#'     return(out)
+#'   })
+#'
+#'   trees2 <- apply(set2, 2, function(asample) {
+#'     samples <- base_data[, asample]
+#'     out <- hclust(dist(samples), method = "ward.D2")
+#'     return(out)
+#'   })
+#'   return(list("trees1" = trees1, "trees2" = trees2))
 #' })
-#' trees1 <- trees[1:ncol(set1)]
-#' trees2 <- trees[1:ncol(set2) + ncol(set1)]
 #'
-#' tree_pvals <- treediff(trees1, trees2, replicates = c(4, 6))
+#' trees1 <- unlist(trees[1,], recursive = FALSE)
+#' trees2 <- unlist(trees[2,], recursive = FALSE)
+#' replicates = c(4, 6)
+#'
+#' tree_pvals <- treediff(trees1, trees2, replicates)
 #' ## 4 p-values, one for each cluster
 #' tree_pvals$p.value
 
 treediff <- function(trees1, trees2, replicates){
+  # Check if `replicates` is numeric vector
+  if (inherits(replicates, "numeric") != TRUE){
+    stop("`replicates` is not a numeric vector")
+  }
 
   # Check if the length of replicates is 2
   if (length(replicates) != 2){
@@ -81,7 +99,19 @@ treediff <- function(trees1, trees2, replicates){
   # Check if the number of clusters is equal for both conditions
   if (length(trees1) / replicates[1] != length(trees2) / replicates[2]) {
     stop(paste("The number of clusters is different between conditions (or",
-               "`replicates' is not correct)."))
+               "`replicates` is not correct)."))
+  }
+
+  # Check the number of leaves is the same for each cluster
+  tree_order1 <- lapply(trees1, "[[", "order")
+  leaves1 <- sapply(tree_order1, length)
+
+  tree_order2 <- lapply(trees2, "[[", "order")
+  leaves2 <- sapply(tree_order2, length)
+
+  if (!identical(unique(leaves1), unique(leaves2))){
+    stop("the number of leaves in one or more clusters is different between ",
+    "the two sets of trees.")
   }
 
   # Merge trees from both conditions
@@ -103,9 +133,10 @@ treediff <- function(trees1, trees2, replicates){
   outp <- compute_pvalue(outs$average_coph, outs$squeezed_var, replicates)
 
   # Aggregate p-values
-  out_aggr <- outp %>% group_by(cluster) %>%
-    summarise("p.value" = min(sort(p.value) / (1:p)) * p) %>%
-    unique()
+  out_aggr <- suppressWarnings(outp %>%
+    group_by(cluster) %>%
+    summarise("p.value" = min(sort(p.value) / (1:p)) * p, .groups = "keep") %>%
+    unique())
 
   # Store results in a list
   data_name <- paste(substitute(trees1), "and", substitute(trees2))
@@ -116,12 +147,21 @@ treediff <- function(trees1, trees2, replicates){
               "p.value.indiv" = outp$p.value)
 
   # Assign class to the list
-  class(out) <- "htest"
+  class(out) <- "treeTest"
 
   # Return result
   return(out)
 }
 
+print.treeTest <- function(x, prefix = "\t",...){
+  cat("\n")
+  cat(strwrap(x$method, prefix = prefix), sep = "\n")
+  cat("\n")
+  cat("data:  ", x$data.name, "\n", sep = "")
+  cat("p-value:  ", "\n", sep = "")
+  cat(strwrap(round(x$p.value, 4), prefix = prefix), sep = "\n")
+}
+
 compute_squeeze <- function(dist_coph, replicates) {
 
   # Calculate number of clusters
@@ -136,7 +176,7 @@ compute_squeeze <- function(dist_coph, replicates) {
 
   # Indices for each group
   set1 <- 1:length(clusters1)
-  set2 <- 1:length(clusters2) + length(col1)
+  set2 <- 1:length(clusters2) + length(set1)
 
   # Average per groups and conditions
   average_coph_trees1 <- lapply(unique(clusters1), function(acluster) {
@@ -151,25 +191,24 @@ compute_squeeze <- function(dist_coph, replicates) {
   # Merge average values and cluster vector
   cluster_length <- sapply(average_coph_trees1, length)
   average_coph <- data.frame("set1" = Reduce(c, average_coph_trees1),
-                             "set2" = Reduce(c, average_coph_trees2), 
+                             "set2" = Reduce(c, average_coph_trees2),
                              "cluster" = rep(unique(clusters1), cluster_length))
 
   # Calculate variance
   sq_average_coph <- sweep(average_coph[-3]^2, 2, replicates, "*")
 
   # Sum of squared values for each group
-  ## FIX IT: PROBLEM WITH VARIANCES HERE
   sum_sq_coph_trees1 <- lapply(unique(clusters1), function(acluster) {
     where_clust <- which(clusters1 == acluster)
-    colMeans(Reduce(rbind, dist_coph[where_clust])^2)
+    colSums(Reduce(rbind, dist_coph[where_clust])^2)
   })
   sum_sq_coph_trees2 <- lapply(unique(clusters1), function(acluster) {
     where_clust <- which(clusters2 == acluster) + length(clusters1)
-    colMeans(Reduce(rbind, dist_coph[where_clust])^2)
+    colSums(Reduce(rbind, dist_coph[where_clust])^2)
   })
-  
+
   sum_sq_coph <- data.frame("set1" = Reduce(c, sum_sq_coph_trees1),
-                            "set2" = Reduce(c, sum_sq_coph_trees2), 
+                            "set2" = Reduce(c, sum_sq_coph_trees2),
                             "cluster" = rep(unique(clusters1), cluster_length))
 
   variances <- sum_sq_coph[, 1:2] - sq_average_coph
diff --git a/man/treediff.Rd b/man/treediff.Rd
index 5a076f12aafe3befef66f348dd8d5f15082cf32f..cdaf4664bb699470ba923458daa1a9fd93853902 100644
--- a/man/treediff.Rd
+++ b/man/treediff.Rd
@@ -43,22 +43,37 @@ method. The p-values are obtained by the treediff method, as described in
 }
 \examples{
 
-base_data <- matrix(rnorm(2000), nrow = 100, ncol = 200)
+leaves <- c(100,120,50,80)
 
-## generates two sets of trees with 4 clusters
+trees <- sapply(leaves, FUN = function(l){
+  base_data <- matrix(rnorm(2000), nrow = l, ncol = 200)
+
+## generates two sets of trees with 4 clusters with 100, 120, 50 and 80
+## leaves respectively
 ## 4 replicates in the first condition and 6 in the second condition
-set1 <- replicate(16, sample(1:100, 50, replace = FALSE))
-set2 <- replicate(24, sample(101:200, 50, replace = FALSE))
 
-trees <- apply(cbind(set1, set2), 2, function(asample) {
-  samples <- base_data[, asample]
-  out <- hclust(dist(samples), method = "ward.D2")
-return(out)
+  set1 <- replicate(4, sample(1:100, 50, replace = FALSE))
+  set2 <- replicate(6, sample(101:200, 50, replace = FALSE))
+
+  trees1 <- apply(set1, 2, function(asample) {
+    samples <- base_data[, asample]
+    out <- hclust(dist(samples), method = "ward.D2")
+    return(out)
+  })
+
+  trees2 <- apply(set2, 2, function(asample) {
+    samples <- base_data[, asample]
+    out <- hclust(dist(samples), method = "ward.D2")
+    return(out)
+  })
+  return(list("trees1" = trees1, "trees2" = trees2))
 })
-trees1 <- trees[1:ncol(set1)]
-trees2 <- trees[1:ncol(set2) + ncol(set1)]
 
-tree_pvals <- treediff(trees1, trees2, replicates = c(4, 6))
+trees1 <- unlist(trees[1,], recursive = FALSE)
+trees2 <- unlist(trees[2,], recursive = FALSE)
+replicates = c(4, 6)
+
+tree_pvals <- treediff(trees1, trees2, replicates)
 ## 4 p-values, one for each cluster
 tree_pvals$p.value
 }
diff --git a/tests/testthat/test-treediff.R b/tests/testthat/test-treediff.R
index fe0968f7caa20f5eb8f651ee74d64cd7621381fc..66b335215c8d3be3a00c2577eab2fc506d8a886f 100644
--- a/tests/testthat/test-treediff.R
+++ b/tests/testthat/test-treediff.R
@@ -1,30 +1,34 @@
-set.seed(12081238) 
-base_data <- matrix(rnorm(2000), nrow = 100, ncol = 200)
-group1 <- replicate(20, sample(1:100, 50, replace = FALSE))
-group2 <- replicate(20, sample(101:200, 50, replace = FALSE))
-conditions <- factor(rep(c(1, 2), each = 20))
-trees <- apply(cbind(group1, group2), 2, function(asample) {
-  samples <- base_data[, asample]
-  out <- hclust(dist(samples), method = "ward.D2")
-  return(out)
+leaves <- c(100,120,50,80)
+
+trees <- sapply(leaves, FUN = function(l){
+  base_data <- matrix(rnorm(2000), nrow = l, ncol = 200)
+
+  set1 <- replicate(4, sample(1:100, 50, replace = FALSE))
+  set2 <- replicate(6, sample(101:200, 50, replace = FALSE))
+
+  trees1 <- apply(set1, 2, function(asample) {
+    samples <- base_data[, asample]
+    out <- hclust(dist(samples), method = "ward.D2")
+    return(out)
+  })
+
+  trees2 <- apply(set2, 2, function(asample) {
+    samples <- base_data[, asample]
+    out <- hclust(dist(samples), method = "ward.D2")
+    return(out)
+  })
+  return(list("trees1" = trees1, "trees2" = trees2))
 })
-trees1 <- trees[1:20]
-trees2 <- trees[21:40]
+
+trees1 <- unlist(trees[1,], recursive = FALSE)
+trees2 <- unlist(trees[2,], recursive = FALSE)
+replicates = c(4, 6)
 
 test_that("'treediff' works for simple cases", {
-  res <- treediff(trees1, trees2, nsim = 1e5)
-  expect_s3_class(res, "htest")
-  expect_s3_class(res, "tree_test")
-  
-  expect_named(res, c("statistic", "parameter", "p.value", 'p.value_Simes',
-                      "alternative", "method", "data.name", "res"))
-  
-  expect_true(is.numeric(res$statistic))
-  expect_true(is.numeric(res$p.value))
-})
+  res <- treediff(trees1, trees2, replicates)
+
+  expect_named(res, c("method", "data.name", "p.value",
+                      "statistic", "p.value.indiv"))
 
-test_that("'threshold.p' option works in a simple case (no truncation)", {
-  res <- treediff(trees1, trees2, threshold.p = 1, nsim = 100)
-  expect_equal(length(res$res$cophenetics[[1]]), 
-               length(res$res$truncated_indices))
+  expect_equal(length(res$statistic), length(res$p.value.indiv))
 })