# Figure S7

# Gene lists that are passed as `input_gene_list` to the enrichment scoring of
# panels S7A-C below.
markers.path <- "path_to_supervised_annotation"
markers <- readRDS(markers.path)

# Copy the PSC, ESC and NSPC entries from a second RDS file into `markers`.
# NOTE(review): `$` on a list allows partial name matching; confirm the three
# names exist verbatim in `markers.stem`.
markers.stem.path <- "path_to_stemness_markers"
markers.stem <- readRDS(markers.stem.path)
markers$PSC <- markers.stem$PSC
markers$ESC <- markers.stem$ESC
markers$NSPC <- markers.stem$NSPC


# Index `markers` by its own names after replacing every "." with "_".
# NOTE(review): this SELECTS elements by the rewritten names, it does not
# rename them. If no name contains a "." the line is a no-op; if one does, the
# rewritten name is looked up in the list and, when absent, yields a NULL
# entry with an NA name. If a rename was intended, it should presumably be
# `names(markers) <- ...` -- confirm against the actual list names.
markers <- markers[stringr::str_replace_all(names(markers), "\\.", "_")]




# Figure S7A ---------
# Load the scRNA-seq object, score the `markers` gene lists with UCell, embed
# and cluster the returned object, then plot the annotation next to the
# "Neuronal.IPC" and "Cycle" features.
path.to.sample <- "path_to_scRNAseq_sample"
sample <- readRDS(path.to.sample)

# return_object = TRUE: the result is a list whose `Object` element is used
# for every step below.
out <- SCpubr::do_EnrichmentHeatmap(
  sample,
  input_gene_list = markers,
  flavor = "UCell",
  return_object = TRUE,
  scale_scores = FALSE
)

sample <- out$Object
sample <- Seurat::ScaleData(sample)
sample <- Seurat::RunPCA(sample, features = rownames(sample))

# Use every principal component that was computed. The embeddings are read
# through the accessor instead of the raw slots, and seq_len() avoids the
# c(1, 0) trap of 1:n when n is 0.
dims <- seq_len(ncol(Seurat::Embeddings(sample, reduction = "pca")))
sample <- Seurat::RunUMAP(sample, dims = dims)
sample <- Seurat::FindNeighbors(sample, reduction = "pca", dims = dims)

# FindClusters() clusters the graph built by FindNeighbors(); it has no
# `reduction` or `dims` parameters, so passing them only produced an
# "arguments are not used" warning. They are dropped here.
sample <- Seurat::FindClusters(sample, resolution = 1.5)

# NOTE(review): a list is passed where SCpubr documents a named vector; kept
# as in the original.
colors.use <- list(
  "IPC-like" = "#be920e",
  "Other" = "grey90"
)

p1 <- SCpubr::do_DimPlot(
  sample = sample,
  group.by = "Annotation",
  colors.use = colors.use,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  legend.icon.size = 8,
  legend.ncol = 2,
  legend.position = "bottom"
)

p3 <- SCpubr::do_FeaturePlot(
  sample = sample,
  features = "Neuronal.IPC",
  slot = "scale.data",
  enforce_symmetry = TRUE,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  order = TRUE,
  legend.position = "bottom"
)

p2 <- SCpubr::do_FeaturePlot(
  sample = sample,
  features = "Cycle",
  slot = "scale.data",
  enforce_symmetry = TRUE,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  order = TRUE,
  legend.position = "bottom"
)

# Panel order: annotation | Neuronal.IPC | Cycle.
# NOTE(review): `|` on plot objects needs patchwork's method to be registered;
# no library(patchwork) call is visible in this script.
p <- p1 | p3 | p2



# Figure S7B ---------
# Same workflow as the previous panel, applied to the snSMARTseq object: UCell
# scoring of `markers`, PCA/UMAP/clustering on the returned object, then the
# annotation plot next to the "Neuronal.IPC" and "Cycle" features.
path.to.sample <- "path_to_snSMARTseq_sample"
sample <- readRDS(path.to.sample)

# return_object = TRUE: the result is a list whose `Object` element is used
# for every step below.
out <- SCpubr::do_EnrichmentHeatmap(
  sample,
  input_gene_list = markers,
  flavor = "UCell",
  return_object = TRUE,
  scale_scores = FALSE
)

sample <- out$Object
sample <- Seurat::ScaleData(sample)
sample <- Seurat::RunPCA(sample, features = rownames(sample))

# Use every principal component that was computed. The embeddings are read
# through the accessor instead of the raw slots, and seq_len() avoids the
# c(1, 0) trap of 1:n when n is 0.
dims <- seq_len(ncol(Seurat::Embeddings(sample, reduction = "pca")))
sample <- Seurat::RunUMAP(sample, dims = dims)
sample <- Seurat::FindNeighbors(sample, reduction = "pca", dims = dims)

# FindClusters() clusters the graph built by FindNeighbors(); it has no
# `reduction` or `dims` parameters, so passing them only produced an
# "arguments are not used" warning. They are dropped here.
sample <- Seurat::FindClusters(sample, resolution = 1.5)


# NOTE(review): a list is passed where SCpubr documents a named vector; kept
# as in the original.
colors.use <- list(
  "IPC-like" = "#be920e",
  "Other" = "grey90"
)
p1 <- SCpubr::do_DimPlot(
  sample = sample,
  group.by = "Annotation",
  colors.use = colors.use,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  legend.icon.size = 8,
  legend.ncol = 2,
  legend.position = "bottom"
)

p3 <- SCpubr::do_FeaturePlot(
  sample = sample,
  features = "Neuronal.IPC",
  slot = "scale.data",
  enforce_symmetry = TRUE,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  order = TRUE,
  legend.position = "bottom"
)

p2 <- SCpubr::do_FeaturePlot(
  sample = sample,
  features = "Cycle",
  slot = "scale.data",
  enforce_symmetry = TRUE,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  order = TRUE,
  legend.position = "bottom"
)

# Panel order: annotation | Neuronal.IPC | Cycle.
# NOTE(review): `|` on plot objects needs patchwork's method to be registered;
# no library(patchwork) call is visible in this script.
p <- p1 | p3 | p2

# Figure S7C ---------
# Same workflow applied to the scSMARTseq object: UCell scoring of `markers`,
# PCA/UMAP/clustering on the returned object, then the annotation plot next to
# the "Neuronal.IPC" and "Cycle" features.
# NOTE: `sample` as left by this section is read again further down the script
# (`sample$Final_Annotation`).
path.to.sample <- "path_to_scSMARTseq_sample"
sample <- readRDS(path.to.sample)

# return_object = TRUE: the result is a list whose `Object` element is used
# for every step below.
out <- SCpubr::do_EnrichmentHeatmap(
  sample,
  input_gene_list = markers,
  flavor = "UCell",
  return_object = TRUE,
  scale_scores = FALSE
)

sample <- out$Object
sample <- Seurat::ScaleData(sample)
sample <- Seurat::RunPCA(sample, features = rownames(sample))

# Use every principal component that was computed. The embeddings are read
# through the accessor instead of the raw slots, and seq_len() avoids the
# c(1, 0) trap of 1:n when n is 0.
dims <- seq_len(ncol(Seurat::Embeddings(sample, reduction = "pca")))
sample <- Seurat::RunUMAP(sample, dims = dims)
sample <- Seurat::FindNeighbors(sample, reduction = "pca", dims = dims)

# FindClusters() clusters the graph built by FindNeighbors(); it has no
# `reduction` or `dims` parameters, so passing them only produced an
# "arguments are not used" warning. They are dropped here.
sample <- Seurat::FindClusters(sample, resolution = 1.5)

# NOTE(review): a list is passed where SCpubr documents a named vector; kept
# as in the original.
colors.use <- list(
  "IPC-like" = "#be920e",
  "Other" = "grey90"
)
p1 <- SCpubr::do_DimPlot(
  sample = sample,
  group.by = "Annotation",
  colors.use = colors.use,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  legend.icon.size = 8,
  legend.ncol = 2,
  legend.position = "bottom"
)

p3 <- SCpubr::do_FeaturePlot(
  sample = sample,
  features = "Neuronal.IPC",
  slot = "scale.data",
  enforce_symmetry = TRUE,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  order = TRUE,
  legend.position = "bottom"
)

p2 <- SCpubr::do_FeaturePlot(
  sample = sample,
  features = "Cycle",
  slot = "scale.data",
  enforce_symmetry = TRUE,
  font.size = 16,
  label.size = 4,
  raster = TRUE,
  raster.dpi = 2048,
  pt.size = 16,
  order = TRUE,
  legend.position = "bottom"
)

# Panel order: annotation | Neuronal.IPC | Cycle.
# NOTE(review): `|` on plot objects needs patchwork's method to be registered;
# no library(patchwork) call is visible in this script.
p <- p1 | p3 | p2


# Figure S7D ---------
# Input and plot settings for the deconvolution heatmaps.

# The path variable was never defined in this script. Provide the same kind of
# placeholder the other sections use, without overriding a value that has
# already been set elsewhere (e.g. in the calling session).
if (!exists("path.to.deconvolution.results")) {
  path.to.deconvolution.results <- "path_to_deconvolution_results"
}
deconv <- readRDS(path.to.deconvolution.results)

# Column names of the long prediction table used for the two heatmap axes.
group.by <- "cell_type"
split.by <- "ID"

# Text and theme settings consumed by the heatmap code below.
font.size <- 14
font.type <- "sans"
plot.title.face <- "bold"
plot.subtitle.face <- "plain"
plot.caption.face <- "italic"
axis.title.face <- "bold"
axis.text.face <- "plain"
grid.color <- "white"
border.color <- "black"
axis.text.x.angle <- 45

# FALSE puts `group.by` on the x axis and `split.by` on the y axis.
flip <- FALSE


# Long table of the deconvolution estimates: one row per (ID, cell_type) pair
# with the estimate in `proportion`. Row names of `deconv$mat` become `ID`,
# its column names become `cell_type`.
# Written with explicit, namespaced calls so that no pipe operator has to be
# attached for this step to run.
predictions <- data.frame(deconv$mat, check.names = FALSE)
predictions <- tibble::rownames_to_column(predictions, var = "ID")
predictions <- tidyr::pivot_longer(
  predictions,
  cols = -"ID",
  names_to = "cell_type",
  values_to = "proportion"
)

# Wide matrix (rows = `group.by` values, columns = `split.by` values) used only
# to derive the sample ordering. all_of() marks the character variables as
# external column names; bare external vectors in selections are deprecated.
df.order <- tidyr::pivot_wider(
  predictions,
  id_cols = dplyr::all_of(group.by),
  names_from = dplyr::all_of(split.by),
  values_from = "proportion"
)
df.order <- as.matrix(tibble::column_to_rownames(df.order, group.by))
df.order[is.na(df.order)] <- 0

# Order the samples by Ward clustering of their Euclidean distances.
# hclust() needs at least two objects, so a single sample keeps its position.
if (ncol(df.order) > 1) {
  id.dist <- stats::dist(t(df.order), method = "euclidean")
  id.clust <- stats::hclust(id.dist, method = "ward.D")
  col_order <- colnames(df.order)[id.clust$order]
} else {
  col_order <- colnames(df.order)
}

# Fixed display order of the cell types.
# NOTE(review): any cell type missing from this vector becomes NA when the
# column is converted to a factor below.
row_order <- c(
  "Astrocytes",
  "Endothelial",
  "Microglia",
  "Neurons",
  "OPC",
  "Pericytes",

  "IPC-like",
  "CP-like",
  "Cilia-like",
  "RG-like",
  "OPC-like",
  "NPC-like",
  "Mesenchymal-like",
  "Immune-like",
  "Hypoxic",

  "NMF-MP2",
  "NMF-MP3",
  "NMF-MP5",
  "NMF-MP8",
  "SHH-1",
  "SHH-2",
  "SHH-3",
  "MYC-1",
  "Unannotated"
)

# Apply both orderings as factor levels and rescale the estimates by 100
# (the heatmaps label the fill scale "Percentage").
predictions <- dplyr::mutate(
  predictions,
  "{group.by}" := factor(.data[[group.by]], levels = row_order),
  "{split.by}" := factor(.data[[split.by]], levels = col_order),
  "proportion" = .data$proportion * 100
)

# Named colour vector keyed by cell-type label.
colors.use <- c(
  # Catch-all categories.
  "Unannotated" = "#C0C0C0",
  "TME" = "grey25",

  # "-like" programs and the hypoxic program.
  "IPC-like" = "#be920e",
  "CP-like" = "#be660e",
  "Cilia-like" = "#be0e0e",
  "Mesenchymal-like" = "#0ebe66",
  "RG-like" = "#0497c8",
  "NPC-like" = "#0466c8",
  "OPC-like" = "#0435c8",
  "Hypoxic" = "#92be0e",
  "Immune-like" = "#920ebe",

  # The six cell types shown in the "TME" heatmap panels.
  "Astrocytes" = "#BA331CFF",
  "Neurons" = "#787F00FF",
  "OPC" = "#009257FF",
  "Microglia" = "#0092AAFF",
  "Endothelial" = "#5E4CCDFF",
  "Pericytes" = "#a32978"
)




# Colour ramp for the heatmap fill; the result is passed as `colors` to
# scale_fill_gradientn() in the heatmaps below.
# NOTE(review): compute_continuous_palette() is an unexported SCpubr helper
# (`:::`), so its signature may change between SCpubr versions.
colors.gradient <- SCpubr:::compute_continuous_palette(name = "YlGnBu",
                                                       use_viridis = FALSE,
                                                       direction = 1,
                                                       enforce_symmetry = FALSE)

# Theme elements (axis ticks/text/titles, strip settings) that the heatmaps
# below read from this list inside their theme() calls.
# NOTE(review): handle_axis() is also an unexported SCpubr helper.
# NOTE(review): `name` is not defined anywhere in the visible part of this
# script; the call only works if the helper never evaluates `group` or if
# `name` exists in the session -- confirm.
# NOTE(review): `sample` here is whatever object the previous section left
# behind, and `Final_Annotation` is not referenced anywhere else in the visible
# script (the DimPlots use "Annotation") -- confirm the column exists.
axis.parameters <- SCpubr:::handle_axis(flip = flip,
                                        group.by = rep("A", length(unique(sample$Final_Annotation))),
                                        group = name,
                                        counter = 1,
                                        axis.text.x.angle = axis.text.x.angle,
                                        plot.title.face = plot.title.face,
                                        plot.subtitle.face = plot.subtitle.face,
                                        plot.caption.face = plot.caption.face,
                                        axis.title.face = axis.title.face,
                                        axis.text.face = axis.text.face,
                                        legend.title.face = "bold",
                                        legend.text.face = "plain")

# Deconvolution heatmaps p1-p3.
# The three panels were separate copies of the same ~95-line plot that differed
# only in the rows kept, the axis titles and whether the right-hand y axis is
# drawn. They are now produced by one helper.

# Cell types shown in the "TME" panels; every other cell type goes to the
# tumor panels.
tme.cell.types <- c(
  "Astrocytes",
  "Endothelial",
  "Microglia",
  "Neurons",
  "OPC",
  "Pericytes"
)

# Samples drawn in the panels labelled "ATRT-SHH".
shh.ids <- c(
  "H049-031P",
  "H049-0P37",
  "H049-33S1",
  "H049-T5XL",
  "H049-WB9V",
  "H049-8XZZ",
  "H049-9U3Q",
  "H049-FG38",
  "H049-GP41",
  "H049-GWNW"
)

# Samples left out of the panels labelled "ATRT-TYR": the ATRT-SHH samples
# plus six further IDs.
# NOTE(review): the six extra IDs are not plotted in any panel visible here;
# presumably they belong to another subgroup -- confirm.
non.tyr.ids <- c(
  shh.ids,
  "H049-3NB9",
  "H049-FEGQ",
  "H049-0CWK",
  "H049-NY74",
  "H049-WYZN",
  "H049-X973"
)

# Draw one proportion heatmap (tiles with the rounded value printed on top)
# and apply the shared colour-bar styling.
#
# data            Rows of `predictions` to plot.
# x.title         Title of the top x axis, or NULL for none.
# y.title         Title of the left y axis, or NULL for none.
# show.right.axis TRUE draws ticks and labels on the right-hand y axis.
#
# The helper reads the script-level settings `flip`, `group.by`, `split.by`,
# `colors.gradient`, `axis.parameters`, `font.size`, `font.type`,
# `plot.title.face`, `plot.subtitle.face`, `plot.caption.face` and
# `border.color`, exactly as the inline code did.
# Returns the plot object produced by SCpubr:::modify_continuous_legend().
plot_deconvolution_heatmap <- function(data,
                                       x.title = NULL,
                                       y.title = NULL,
                                       show.right.axis = FALSE) {
  if (isTRUE(show.right.axis)) {
    right.ticks <- ggplot2::element_line(color = "black")
    right.text <- ggplot2::element_text(color = "black")
  } else {
    right.ticks <- ggplot2::element_blank()
    right.text <- ggplot2::element_blank()
  }

  p <- ggplot2::ggplot(data = data,
                       mapping = ggplot2::aes(x = if(base::isFALSE(flip)){.data[[group.by]]} else {.data[[split.by]]},
                                              y = if(base::isFALSE(flip)){.data[[split.by]]} else {.data[[group.by]]},
                                              fill = .data$proportion)) +
    ggplot2::geom_tile(color = "white", linewidth = 0.5) +
    # Switch to white text on tiles above 50 so the number stays readable.
    ggplot2::geom_text(ggplot2::aes(label = round(.data$proportion, 1),
                                    color = ifelse(.data$proportion > 50, "white", "black")),
                       size = 3) +
    ggplot2::scale_y_discrete(expand = c(0, 0)) +
    ggplot2::scale_x_discrete(expand = c(0, 0),
                              position = "top") +
    ggplot2::guides(y.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[split.by]]))),
                    x.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[group.by]])))) +
    ggplot2::coord_equal() +
    ggplot2::scale_color_identity() +
    ggplot2::scale_fill_gradientn(colors = colors.gradient,
                                  na.value = "grey75",
                                  name = "Percentage",
                                  breaks = c(0, 25, 50, 75, 100),
                                  labels = c("0", "25", "50", "75", "100"),
                                  limits = c(0, 100)) +
    ggplot2::xlab(x.title) +
    ggplot2::ylab(y.title) +
    ggplot2::theme_minimal(base_size = font.size) +
    ggplot2::theme(axis.ticks.x.bottom = ggplot2::element_blank(),
                   axis.ticks.x.top = axis.parameters$axis.ticks.x.top,
                   axis.ticks.y.left = axis.parameters$axis.ticks.y.left,
                   axis.ticks.y.right = right.ticks,
                   axis.text.y.left = axis.parameters$axis.text.y.left,
                   axis.text.y.right = right.text,
                   axis.text.x.top = axis.parameters$axis.text.x.top,
                   axis.text.x.bottom = ggplot2::element_blank(),
                   axis.title.x.bottom = axis.parameters$axis.title.x.bottom,
                   axis.title.x.top = ggplot2::element_text(color = "black", face = "bold"),
                   axis.title.y.right = ggplot2::element_blank(),
                   axis.title.y.left = axis.parameters$axis.title.y.left,
                   strip.background = axis.parameters$strip.background,
                   strip.clip = axis.parameters$strip.clip,
                   strip.text = axis.parameters$strip.text,
                   legend.position = "bottom",
                   axis.line = ggplot2::element_blank(),
                   plot.title = ggplot2::element_text(face = plot.title.face, hjust = 0),
                   plot.subtitle = ggplot2::element_text(face = plot.subtitle.face, hjust = 0),
                   plot.caption = ggplot2::element_text(face = plot.caption.face, hjust = 1),
                   plot.title.position = "plot",
                   panel.grid = ggplot2::element_blank(),
                   panel.grid.minor.y = ggplot2::element_line(color = "white", linewidth = 1),
                   text = ggplot2::element_text(family = font.type),
                   plot.caption.position = "plot",
                   legend.text = ggplot2::element_text(face = "plain", size = font.size),
                   legend.title = ggplot2::element_text(face = "bold", size = font.size),
                   legend.justification = "center",
                   plot.margin = ggplot2::margin(t = 0, r = 0, b = 0, l = 0, unit = "mm"),
                   panel.border = ggplot2::element_rect(fill = NA, color = border.color, linewidth = 1),
                   panel.grid.major = ggplot2::element_blank(),
                   plot.background = ggplot2::element_rect(fill = "white", color = "white"),
                   panel.background = ggplot2::element_rect(fill = "white", color = "white"),
                   legend.background = ggplot2::element_rect(fill = "white", color = "white"),
                   legend.spacing = ggplot2::unit(0, "cm"),
                   panel.spacing.x = ggplot2::unit(0, "cm"),
                   panel.spacing.y = ggplot2::unit(0, "cm"))

  SCpubr:::modify_continuous_legend(p = p,
                                    legend.title = "Percentage",
                                    legend.aes = "fill",
                                    legend.type = "colorbar",
                                    legend.position = "bottom",
                                    legend.width = 1,
                                    legend.length = 12.5,
                                    legend.framewidth = 0.5,
                                    legend.tickwidth = 0.5,
                                    legend.framecolor = "grey50",
                                    legend.tickcolor = "white")
}

# p1: TME cell types, ATRT-TYR samples.
p1 <- plot_deconvolution_heatmap(
  data = dplyr::filter(predictions,
                       .data$cell_type %in% tme.cell.types,
                       !(.data$ID %in% non.tyr.ids)),
  x.title = "TME",
  y.title = "ATRT-TYR",
  show.right.axis = FALSE
)

# p2: all remaining cell types, ATRT-TYR samples; sample labels on the right.
p2 <- plot_deconvolution_heatmap(
  data = dplyr::filter(predictions,
                       !(.data$cell_type %in% tme.cell.types),
                       !(.data$ID %in% non.tyr.ids)),
  x.title = "Tumor",
  y.title = NULL,
  show.right.axis = TRUE
)

# p3: TME cell types, ATRT-SHH samples.
p3 <- plot_deconvolution_heatmap(
  data = dplyr::filter(predictions,
                       .data$cell_type %in% tme.cell.types,
                       .data$ID %in% shh.ids),
  x.title = NULL,
  y.title = "ATRT-SHH",
  show.right.axis = FALSE
)


p4 <- predictions %>% 
  dplyr::filter(!(.data$cell_type %in% c("Astrocytes",
                                         "Endothelial",
                                         "Microglia",
                                         "Neurons",
                                         "OPC",
                                         "Pericytes")),
                .data$ID %in% c("H049-031P",
                                "H049-0P37",
                                "H049-33S1",
                                "H049-T5XL",
                                "H049-WB9V",
                                "H049-8XZZ",
                                "H049-9U3Q",
                                "H049-FG38",
                                "H049-GP41",
                                "H049-GWNW")) %>% 
  ggplot2::ggplot(mapping = ggplot2::aes(x = if(base::isFALSE(flip)){.data[[group.by]]} else {.data[[split.by]]},
                                         y = if(base::isFALSE(flip)){.data[[split.by]]} else {.data[[group.by]]},
                                         fill = .data$proportion)) + 
  ggplot2::geom_tile(color = "white", linewidth = 0.5) +
  ggplot2::geom_text(ggplot2::aes(label = round(.data$proportion, 1), 
                                  color = ifelse(.data$proportion > 50, "white", "black")), 
                     size = 3) +
  ggplot2::scale_y_discrete(expand = c(0, 0)) +
  ggplot2::scale_x_discrete(expand = c(0, 0),
                            position = "top") +
  ggplot2::guides(y.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[split.by]]))),
                  x.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[group.by]])))) + 
  ggplot2::coord_equal() +
  ggplot2::scale_color_identity() + 
  ggplot2::scale_fill_gradientn(colors = colors.gradient,
                                na.value = "grey75",
                                name = "Percentage",
                                breaks = c(0, 25, 50, 75, 100),
                                labels = c("0", "25", "50", "75", "100"),
                                limits = c(0, 100)) + 
  ggplot2::xlab(NULL) +
  ggplot2::ylab(NULL) +
  ggplot2::theme_minimal(base_size = font.size) +
  ggplot2::theme(axis.ticks.x.bottom = ggplot2::element_blank(),
                 axis.ticks.x.top = axis.parameters$axis.ticks.x.top,
                 axis.ticks.y.left = axis.parameters$axis.ticks.y.left,
                 axis.ticks.y.right = ggplot2::element_line(color = "black"),
                 axis.text.y.left = axis.parameters$axis.text.y.left,
                 axis.text.y.right = ggplot2::element_text(color = "black"),
                 axis.text.x.top = axis.parameters$axis.text.x.top,
                 axis.text.x.bottom = ggplot2::element_blank(),
                 axis.title.x.bottom = axis.parameters$axis.title.x.bottom,
                 axis.title.x.top = ggplot2::element_text(color = "black", face = "bold"),
                 axis.title.y.right = ggplot2::element_blank(),
                 axis.title.y.left = axis.parameters$axis.title.y.left,
                 strip.background = axis.parameters$strip.background,
                 strip.clip = axis.parameters$strip.clip,
                 strip.text = axis.parameters$strip.text,
                 legend.position = "bottom",
                 axis.line = ggplot2::element_blank(),
                 plot.title = ggplot2::element_text(face = plot.title.face, hjust = 0),
                 plot.subtitle = ggplot2::element_text(face = plot.subtitle.face, hjust = 0),
                 plot.caption = ggplot2::element_text(face = plot.caption.face, hjust = 1),
                 plot.title.position = "plot",
                 panel.grid = ggplot2::element_blank(),
                 panel.grid.minor.y = ggplot2::element_line(color = "white", linewidth = 1),
                 text = ggplot2::element_text(family = font.type),
                 plot.caption.position = "plot",
                 legend.text = ggplot2::element_text(face = "plain", size = font.size),
                 legend.title = ggplot2::element_text(face = "bold", size = font.size),
                 legend.justification = "center",
                 plot.margin = ggplot2::margin(t = 0, r = 0, b = 0, l = 0, unit = "mm"),
                 panel.border = ggplot2::element_rect(fill = NA, color = border.color, linewidth = 1),
                 panel.grid.major = ggplot2::element_blank(),
                 plot.background = ggplot2::element_rect(fill = "white", color = "white"),
                 panel.background = ggplot2::element_rect(fill = "white", color = "white"),
                 legend.background = ggplot2::element_rect(fill = "white", color = "white"),
                 legend.spacing = ggplot2::unit(0, "cm"),
                 panel.spacing.x = ggplot2::unit(0, "cm"),
                 panel.spacing.y = ggplot2::unit(0, "cm"))

# Restyle the continuous "fill" legend of p4 with SCpubr's internal helper:
# a colorbar titled "Percentage", placed at the bottom of the plot.
p4 <- SCpubr:::modify_continuous_legend(
  p = p4,
  legend.aes = "fill",
  legend.type = "colorbar",
  legend.title = "Percentage",
  legend.position = "bottom",
  # Colorbar size.
  legend.length = 12.5,
  legend.width = 1,
  # Frame and tick appearance.
  legend.framecolor = "grey50",
  legend.framewidth = 0.5,
  legend.tickcolor = "white",
  legend.tickwidth = 0.5
)


# p5: tile heatmap of `proportion` restricted to six cell types (Astrocytes,
# Endothelial, Microglia, Neurons, OPC, Pericytes) and to the six sample IDs
# listed below. The y axis is titled "ATRT-MYC" and the right-hand y axis
# text and ticks are blanked.
p5 <- predictions %>%
  dplyr::filter(
    .data$cell_type %in% c("Astrocytes", "Endothelial", "Microglia",
                           "Neurons", "OPC", "Pericytes"),
    .data$ID %in% c("H049-3NB9", "H049-FEGQ", "H049-0CWK",
                    "H049-NY74", "H049-WYZN", "H049-X973")
  ) %>%
  ggplot2::ggplot(
    mapping = ggplot2::aes(
      # `flip` decides which of `group.by` / `split.by` goes on which axis.
      x = if (base::isFALSE(flip)) .data[[group.by]] else .data[[split.by]],
      y = if (base::isFALSE(flip)) .data[[split.by]] else .data[[group.by]],
      fill = .data$proportion
    )
  ) +
  ggplot2::geom_tile(color = "white", linewidth = 0.5) +
  # Print the rounded value in each tile; white text above 50, black otherwise.
  ggplot2::geom_text(
    mapping = ggplot2::aes(
      label = round(.data$proportion, 1),
      color = ifelse(.data$proportion > 50, "white", "black")
    ),
    size = 3
  ) +
  ggplot2::scale_x_discrete(position = "top",
                            expand = c(0, 0)) +
  ggplot2::scale_y_discrete(expand = c(0, 0)) +
  ggplot2::guides(
    y.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[split.by]]))),
    x.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[group.by]])))
  ) +
  ggplot2::coord_equal() +
  # The text colors computed above are literal color names.
  ggplot2::scale_color_identity() +
  ggplot2::scale_fill_gradientn(
    name = "Percentage",
    colors = colors.gradient,
    limits = c(0, 100),
    breaks = c(0, 25, 50, 75, 100),
    labels = c("0", "25", "50", "75", "100"),
    na.value = "grey75"
  ) +
  ggplot2::labs(x = NULL, y = "ATRT-MYC") +
  ggplot2::theme_minimal(base_size = font.size) +
  ggplot2::theme(
    # Axes ---------------------------------------------------------------
    axis.line = ggplot2::element_blank(),
    axis.ticks.x.top = axis.parameters$axis.ticks.x.top,
    axis.ticks.x.bottom = axis.parameters$axis.ticks.x.bottom,
    axis.ticks.y.left = axis.parameters$axis.ticks.y.left,
    axis.ticks.y.right = ggplot2::element_blank(),
    axis.text.x.top = axis.parameters$axis.text.x.top,
    axis.text.x.bottom = axis.parameters$axis.text.x.bottom,
    axis.text.y.left = axis.parameters$axis.text.y.left,
    axis.text.y.right = ggplot2::element_blank(),
    axis.title.x.top = ggplot2::element_text(color = "black", face = "bold"),
    axis.title.x.bottom = axis.parameters$axis.title.x.bottom,
    axis.title.y.left = axis.parameters$axis.title.y.left,
    axis.title.y.right = ggplot2::element_blank(),
    # Facet strips ---------------------------------------------------------
    strip.background = axis.parameters$strip.background,
    strip.clip = axis.parameters$strip.clip,
    strip.text = axis.parameters$strip.text,
    # Titles and fonts -----------------------------------------------------
    text = ggplot2::element_text(family = font.type),
    plot.title = ggplot2::element_text(face = plot.title.face, hjust = 0),
    plot.subtitle = ggplot2::element_text(face = plot.subtitle.face, hjust = 0),
    plot.caption = ggplot2::element_text(face = plot.caption.face, hjust = 1),
    plot.title.position = "plot",
    plot.caption.position = "plot",
    # Legend ---------------------------------------------------------------
    legend.position = "bottom",
    legend.justification = "center",
    legend.text = ggplot2::element_text(face = "plain", size = font.size),
    legend.title = ggplot2::element_text(face = "bold", size = font.size),
    legend.background = ggplot2::element_rect(fill = "white", color = "white"),
    legend.spacing = ggplot2::unit(0, "cm"),
    # Panel and plot area --------------------------------------------------
    panel.grid = ggplot2::element_blank(),
    panel.grid.major = ggplot2::element_blank(),
    panel.grid.minor.y = ggplot2::element_line(color = "white", linewidth = 1),
    panel.border = ggplot2::element_rect(fill = NA, color = border.color, linewidth = 1),
    panel.background = ggplot2::element_rect(fill = "white", color = "white"),
    panel.spacing.x = ggplot2::unit(0, "cm"),
    panel.spacing.y = ggplot2::unit(0, "cm"),
    plot.background = ggplot2::element_rect(fill = "white", color = "white"),
    plot.margin = ggplot2::margin(t = 0, r = 0, b = 0, l = 0, unit = "mm")
  )

# Restyle the continuous "fill" legend of p5 as a bottom colorbar.
p5 <- SCpubr:::modify_continuous_legend(
  p = p5,
  legend.aes = "fill",
  legend.type = "colorbar",
  legend.title = "Percentage",
  legend.position = "bottom",
  legend.length = 12.5,
  legend.width = 1,
  legend.framecolor = "grey50",
  legend.framewidth = 0.5,
  legend.tickcolor = "white",
  legend.tickwidth = 0.5
)


# p6: tile heatmap of `proportion` for the same six sample IDs as p5, but
# keeping every cell type EXCEPT Astrocytes, Endothelial, Microglia, Neurons,
# OPC and Pericytes. No axis titles are set here; the right-hand y axis text
# and ticks are drawn in black.
p6 <- predictions %>%
  dplyr::filter(
    !(.data$cell_type %in% c("Astrocytes", "Endothelial", "Microglia",
                             "Neurons", "OPC", "Pericytes")),
    .data$ID %in% c("H049-3NB9", "H049-FEGQ", "H049-0CWK",
                    "H049-NY74", "H049-WYZN", "H049-X973")
  ) %>%
  ggplot2::ggplot(
    mapping = ggplot2::aes(
      # `flip` decides which of `group.by` / `split.by` goes on which axis.
      x = if (base::isFALSE(flip)) .data[[group.by]] else .data[[split.by]],
      y = if (base::isFALSE(flip)) .data[[split.by]] else .data[[group.by]],
      fill = .data$proportion
    )
  ) +
  ggplot2::geom_tile(color = "white", linewidth = 0.5) +
  # Print the rounded value in each tile; white text above 50, black otherwise.
  ggplot2::geom_text(
    mapping = ggplot2::aes(
      label = round(.data$proportion, 1),
      color = ifelse(.data$proportion > 50, "white", "black")
    ),
    size = 3
  ) +
  ggplot2::scale_x_discrete(position = "top",
                            expand = c(0, 0)) +
  ggplot2::scale_y_discrete(expand = c(0, 0)) +
  ggplot2::guides(
    y.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[split.by]]))),
    x.sec = SCpubr:::guide_axis_label_trans(~paste0(levels(.data[[group.by]])))
  ) +
  ggplot2::coord_equal() +
  # The text colors computed above are literal color names.
  ggplot2::scale_color_identity() +
  ggplot2::scale_fill_gradientn(
    name = "Percentage",
    colors = colors.gradient,
    limits = c(0, 100),
    breaks = c(0, 25, 50, 75, 100),
    labels = c("0", "25", "50", "75", "100"),
    na.value = "grey75"
  ) +
  ggplot2::labs(x = NULL, y = NULL) +
  ggplot2::theme_minimal(base_size = font.size) +
  ggplot2::theme(
    # Axes ---------------------------------------------------------------
    axis.line = ggplot2::element_blank(),
    axis.ticks.x.top = axis.parameters$axis.ticks.x.top,
    axis.ticks.x.bottom = axis.parameters$axis.ticks.x.bottom,
    axis.ticks.y.left = axis.parameters$axis.ticks.y.left,
    axis.ticks.y.right = ggplot2::element_line(color = "black"),
    axis.text.x.top = axis.parameters$axis.text.x.top,
    axis.text.x.bottom = axis.parameters$axis.text.x.bottom,
    axis.text.y.left = axis.parameters$axis.text.y.left,
    axis.text.y.right = ggplot2::element_text(color = "black"),
    axis.title.x.top = ggplot2::element_text(color = "black", face = "bold"),
    axis.title.x.bottom = axis.parameters$axis.title.x.bottom,
    axis.title.y.left = axis.parameters$axis.title.y.left,
    axis.title.y.right = ggplot2::element_blank(),
    # Facet strips ---------------------------------------------------------
    strip.background = axis.parameters$strip.background,
    strip.clip = axis.parameters$strip.clip,
    strip.text = axis.parameters$strip.text,
    # Titles and fonts -----------------------------------------------------
    text = ggplot2::element_text(family = font.type),
    plot.title = ggplot2::element_text(face = plot.title.face, hjust = 0),
    plot.subtitle = ggplot2::element_text(face = plot.subtitle.face, hjust = 0),
    plot.caption = ggplot2::element_text(face = plot.caption.face, hjust = 1),
    plot.title.position = "plot",
    plot.caption.position = "plot",
    # Legend ---------------------------------------------------------------
    legend.position = "bottom",
    legend.justification = "center",
    legend.text = ggplot2::element_text(face = "plain", size = font.size),
    legend.title = ggplot2::element_text(face = "bold", size = font.size),
    legend.background = ggplot2::element_rect(fill = "white", color = "white"),
    legend.spacing = ggplot2::unit(0, "cm"),
    # Panel and plot area --------------------------------------------------
    panel.grid = ggplot2::element_blank(),
    panel.grid.major = ggplot2::element_blank(),
    panel.grid.minor.y = ggplot2::element_line(color = "white", linewidth = 1),
    panel.border = ggplot2::element_rect(fill = NA, color = border.color, linewidth = 1),
    panel.background = ggplot2::element_rect(fill = "white", color = "white"),
    panel.spacing.x = ggplot2::unit(0, "cm"),
    panel.spacing.y = ggplot2::unit(0, "cm"),
    plot.background = ggplot2::element_rect(fill = "white", color = "white"),
    plot.margin = ggplot2::margin(t = 0, r = 0, b = 0, l = 0, unit = "mm")
  )

# Restyle the continuous "fill" legend of p6 as a bottom colorbar.
p6 <- SCpubr:::modify_continuous_legend(
  p = p6,
  legend.aes = "fill",
  legend.type = "colorbar",
  legend.title = "Percentage",
  legend.position = "bottom",
  legend.length = 12.5,
  legend.width = 1,
  legend.framecolor = "grey50",
  legend.framewidth = 0.5,
  legend.tickcolor = "white",
  legend.tickwidth = 0.5
)




# Assemble the six panels into one column of three rows (p1|p2, p3|p4,
# p5|p6) with relative heights 9:10:6, collecting the legends and placing
# them at the bottom of the composite figure without outer margins.
p <- patchwork::wrap_plots(
  A = p1 | p2,
  B = p3 | p4,
  C = p5 | p6,
  ncol = 1,
  guides = "collect",
  heights = c(9, 10, 6)
) +
  patchwork::plot_annotation(
    theme = ggplot2::theme(
      plot.margin = ggplot2::margin(t = 0, r = 0, b = 0, l = 0, unit = "mm"),
      legend.position = "bottom"
    )
  )



# Data for the stacked barplot ------------------------------------------------
# Starting from `predictions`, every row whose cell type is one of the six
# listed in `tme.cell.types` is summed into a single "TME" row per sample ID;
# all remaining rows are kept unchanged.
proportions.barplot <- as.data.frame(predictions)

# Cell types collapsed into "TME". Defined once so that the two complementary
# subsets below always use the same list.
tme.cell.types <- c("Astrocytes", "Neurons", "OPC", "Microglia", "Endothelial", "Pericytes")

# One "TME" row per ID holding the summed proportion of the collapsed types.
proportions.barplot.tme <- proportions.barplot[proportions.barplot$cell_type %in% tme.cell.types, ] %>%
  dplyr::group_by(ID) %>%
  dplyr::summarise("cell_type" = "TME",
                   "proportion" = sum(.data$proportion))

# Everything that is not collapsed into "TME".
proportions.barplot.tumor <- proportions.barplot[!(proportions.barplot$cell_type %in% tme.cell.types), ]

proportions.barplot <- rbind(proportions.barplot.tme, proportions.barplot.tumor)

# Work with plain strings until the final factor conversion below. No further
# relabelling to "TME" is required: the collapsed cell types only exist in the
# summarised rows, which already carry that label.
proportions.barplot$cell_type <- as.character(proportions.barplot$cell_type)

# Assign each sample to a subtype: "ATRT-TYR" unless the ID is explicitly
# listed as "ATRT-SHH" or "ATRT-MYC".
proportions.barplot$subtype <- "ATRT-TYR"
proportions.barplot$subtype[proportions.barplot$ID %in% c("H049-031P",
                                                          "H049-0P37",
                                                          "H049-33S1",
                                                          "H049-T5XL",
                                                          "H049-WB9V",
                                                          "H049-8XZZ",
                                                          "H049-9U3Q",
                                                          "H049-FG38",
                                                          "H049-GP41",
                                                          "H049-GWNW")] <- "ATRT-SHH"

proportions.barplot$subtype[proportions.barplot$ID %in% c("H049-3NB9",
                                                          "H049-FEGQ",
                                                          "H049-0CWK",
                                                          "H049-NY74",
                                                          "H049-WYZN",
                                                          "H049-X973")] <- "ATRT-MYC"

proportions.barplot$subtype <- factor(proportions.barplot$subtype, levels = c("ATRT-TYR", "ATRT-SHH", "ATRT-MYC"))

# Order the samples by decreasing "IPC-like" proportion.
id.order <- proportions.barplot %>%
  dplyr::filter(.data$cell_type == "IPC-like") %>%
  dplyr::arrange(dplyr::desc(.data$proportion)) %>%
  dplyr::pull(.data$ID)

# Samples without an "IPC-like" row are appended after the ranked ones, and
# repeated IDs are dropped. Without this, factor() would silently turn the
# missing IDs into NA and would fail on duplicated levels. When every sample
# has exactly one "IPC-like" row the order is unchanged.
id.order <- unique(c(as.character(id.order), as.character(proportions.barplot$ID)))
proportions.barplot$ID <- factor(proportions.barplot$ID, levels = rev(id.order))

# Cell-type order; the factor levels are this vector reversed.
celltype.order <- c("TME", "Unannotated",
                    "Immune-like", "Hypoxic",
                    "CP-like", "Cilia-like",
                    "RG-like", "NPC-like", "OPC-like",
                    "Mesenchymal-like",
                    "IPC-like")

proportions.barplot$cell_type <- factor(proportions.barplot$cell_type, levels = rev(celltype.order))

# Named fill colors, keyed by cell-type label. Besides the categories used in
# the barplot ("TME", "Unannotated" and the "-like" / "Hypoxic" states), the
# six individual cell types that are collapsed into "TME" keep their own
# entries.
colors.use <- c(
  "Unannotated"      = "#C0C0C0",
  "TME"              = "grey25",

  "IPC-like"         = "#be920e",

  "CP-like"          = "#be660e",
  "Cilia-like"       = "#be0e0e",

  "Mesenchymal-like" = "#0ebe66",

  "RG-like"          = "#0497c8",
  "NPC-like"         = "#0466c8",
  "OPC-like"         = "#0435c8",

  "Hypoxic"          = "#92be0e",
  "Immune-like"      = "#920ebe",

  # Individual cell types (8-digit codes carry an explicit alpha channel).
  "Astrocytes"       = "#BA331CFF",
  "Neurons"          = "#787F00FF",
  "OPC"              = "#009257FF",
  "Microglia"        = "#0092AAFF",
  "Endothelial"      = "#5E4CCDFF",
  "Pericytes"        = "#a32978"
)


# p0: horizontal stacked barplot with one bar per sample ID, filled by cell
# type and rescaled by position = "fill". Samples are split into one facet
# row per subtype, with free scales and free panel heights.
p0 <- proportions.barplot %>%
  ggplot2::ggplot(
    mapping = ggplot2::aes(
      x = .data$proportion,
      y = .data$ID,
      fill = .data$cell_type
    )
  ) +
  ggplot2::geom_col(position = "fill",
                    color = "black",
                    linewidth = 0.25) +
  # Empty legend title; colors are matched to cell types by name.
  ggplot2::scale_fill_manual(name = "", values = colors.use) +
  ggplot2::facet_grid(rows = ggplot2::vars(subtype),
                      space = "free",
                      scales = "free") +
  ggplot2::labs(x = "Proportion",
                y = "ID") +
  ggplot2::theme_minimal(base_size = 20) +
  ggplot2::theme(
    # Axes ---------------------------------------------------------------
    axis.title = ggplot2::element_text(color = "black", face = "bold"),
    axis.line.x = ggplot2::element_blank(),
    axis.line.y = ggplot2::element_line(color = "black"),
    axis.text.x = ggplot2::element_text(color = "black", face = "plain", angle = 0),
    axis.text.y = ggplot2::element_text(color = "black", face = "plain"),
    axis.ticks.x = ggplot2::element_line(color = "black"),
    axis.ticks.y = ggplot2::element_line(color = "black"),
    # Facet strips ---------------------------------------------------------
    strip.text = ggplot2::element_text(color = "black", face = "bold"),
    strip.background = ggplot2::element_blank(),
    # Titles and fonts -----------------------------------------------------
    text = ggplot2::element_text(family = "sans"),
    plot.title = ggplot2::element_text(face = "bold", hjust = 0.5),
    plot.subtitle = ggplot2::element_text(face = "plain", hjust = 0),
    plot.caption = ggplot2::element_text(face = "italic", hjust = 1),
    plot.title.position = "plot",
    plot.caption.position = "plot",
    # Legend ---------------------------------------------------------------
    legend.position = "bottom",
    legend.justification = "center",
    legend.text = ggplot2::element_text(face = "plain"),
    legend.title = ggplot2::element_text(face = "bold"),
    legend.background = ggplot2::element_rect(fill = "white", color = "white"),
    # Panel and plot area --------------------------------------------------
    panel.grid = ggplot2::element_blank(),
    panel.background = ggplot2::element_rect(fill = "white", color = "white"),
    plot.background = ggplot2::element_rect(fill = "white", color = "white"),
    plot.margin = ggplot2::margin(t = 0, r = 10, b = 0, l = 10)
  )