6. Understanding marginal plots

Marginal plots are a type of data visualization that combines a scatter plot (or other two-variable plot) with additional plots along the margins, typically showing the distribution of each individual variable. These marginal plots are usually histograms, density plots, or box plots, placed along the x-axis and y-axis to visualize the univariate distributions of each variable independently, alongside their bivariate relationship in the main scatter plot. By doing so, marginal plots allow viewers to analyze both the joint distribution of two variables (via the scatter plot) and the individual distributions (via the marginal plots) within a single visualization.

Marginal plots are particularly useful because they show the joint and the individual behavior of the data in a single figure. While a scatter plot shows how two variables relate to each other, the marginal plots show how each variable is distributed on its own. This dual perspective helps detect patterns such as skewness, multimodality, or outliers in each variable, while also showing how the variables interact. For example, you might notice that one variable is approximately normally distributed while the other is skewed, or you might spot clusters or gaps in the bivariate relationship. Such features are harder to identify without the marginal distributions.

These plots are commonly used in exploratory data analysis, especially for examining the relationship between two continuous variables, in fields such as economics, biology, and machine learning. However, marginal plots are not always suitable. If you are working with categorical data, or if the univariate distributions are not of interest, the marginal plots may add unnecessary complexity. Additionally, with high-dimensional data, focusing on two variables at a time may oversimplify relationships that depend on several variables.

Getting ready

In addition to plotly, numpy, and pandas, make sure the scipy Python library is available in your Python environment. You can install it using the command:

pip install scipy 
  1. Import the Python modules numpy and pandas, and import the multivariate_normal object from scipy.stats. This object lets us draw random samples from a multivariate normal distribution, which we use to create the data set for this recipe.

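import numpy as np
import pandas as pd
from scipy.stats import multivariate_normal

# Bivariate normal with mean (1, 3) and the given covariance matrix
rv = multivariate_normal([1.0, 3.0], [[1.0, 0.3], [0.3, 0.5]])
n = 200  # number of observations to draw
sample = rv.rvs(n)  # array of shape (n, 2)
# One row per observation, with columns X and Y
data = pd.DataFrame(sample, columns=['X', 'Y'])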
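Before plotting, it can help to confirm that the sample looks like the distribution it was drawn from. The following is a minimal sketch of such a check, not part of the original recipe. The `seed=0` argument and the names `mean`, `cov`, `sample_mean`, `sample_cov`, and `sample_corr` are assumptions introduced here so that the check is repeatable.

```python
import numpy as np
import pandas as pd
from scipy.stats import multivariate_normal

# Parameters of the bivariate normal used in this recipe
mean = [1.0, 3.0]
cov = [[1.0, 0.3], [0.3, 0.5]]

# seed=0 is an assumption added for repeatability; the recipe does not set a seed
rv = multivariate_normal(mean, cov, seed=0)
data = pd.DataFrame(rv.rvs(200), columns=["X", "Y"])

# Sample estimates, to compare with the parameters above
sample_mean = data.mean().to_numpy()
sample_cov = data.cov().to_numpy()
sample_corr = data["X"].corr(data["Y"])

print("sample mean:", sample_mean)
print("sample covariance:")
print(sample_cov)
print("sample correlation:", sample_corr)
```

With only 200 observations the estimates will not match the parameters exactly. The theoretical correlation implied by the covariance matrix is 0.3 / sqrt(1.0 * 0.5), or about 0.42, so the scatter plot should show a moderate positive trend.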

How to do it

  1. Import the plotly.express module as px and assign the data set created above to df.

import plotly.express as px
df = data
  1. Create a scatter plot with box-plot marginals by calling the function scatter with the arguments marginal_x and marginal_y set to 'box'.

# Scatter plot of X against Y, with a box plot of each variable in the margins
fig = px.scatter(df, x="X", y="Y",
                 marginal_x="box", marginal_y="box",
                 height=500, width=800,
                 title='Sample from a Bivariate Normal Distribution')
fig.show()
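To read a box-plot marginal, it helps to know the quantities it summarizes: the median, the quartiles, and fences placed 1.5 interquartile ranges beyond the quartiles, with points outside the fences drawn as outliers. The sketch below computes these for one variable with numpy. It is an illustration under assumptions: the data are regenerated with numpy's random generator (seed 0) as a stand-in for the recipe's `data`, and numpy's default percentile interpolation may differ slightly from the quartile method plotly uses.

```python
import numpy as np

# Stand-in for the recipe's X column (assumption: numpy generator with seed 0)
rng = np.random.default_rng(0)
sample = rng.multivariate_normal([1.0, 3.0], [[1.0, 0.3], [0.3, 0.5]], size=200)
x = sample[:, 0]

# Quartiles and interquartile range
q1, median, q3 = np.percentile(x, [25, 50, 75])
iqr = q3 - q1

# Conventional 1.5 * IQR fences; points beyond them are flagged as outliers
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr
outliers = x[(x < lower_fence) | (x > upper_fence)]

print(f"Q1={q1:.3f}, median={median:.3f}, Q3={q3:.3f}, IQR={iqr:.3f}")
print(f"fences: [{lower_fence:.3f}, {upper_fence:.3f}]")
print(f"number of flagged outliers: {outliers.size}")
```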
  1. Create a 2-D histogram with histogram marginals by calling the function density_heatmap with the arguments marginal_x and marginal_y set to 'histogram'.

# 2-D histogram of (X, Y), with a 1-D histogram of each variable in the margins
fig = px.density_heatmap(df, x="X", y="Y",
                         marginal_x="histogram", marginal_y="histogram",
                         nbinsx=25,
                         nbinsy=25,
                         histnorm='probability density',
                         height=500, width=800,
                         title='Sample from a Bivariate Normal Distribution')
fig.show()
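The link between the 2-D histogram and its marginal histograms can be checked numerically: integrating the joint density over one variable gives the density of the other. The sketch below does this with numpy rather than plotly. It rests on assumptions: the data are regenerated with numpy's random generator (seed 0), and the bin edges numpy chooses are not necessarily the ones plotly picks for `nbinsx=25` and `nbinsy=25`.

```python
import numpy as np

# Stand-in for the recipe's data (assumption: numpy generator with seed 0)
rng = np.random.default_rng(0)
sample = rng.multivariate_normal([1.0, 3.0], [[1.0, 0.3], [0.3, 0.5]], size=200)
x, y = sample[:, 0], sample[:, 1]

# Joint histogram normalized as a probability density
joint, x_edges, y_edges = np.histogram2d(x, y, bins=25, density=True)
dx = np.diff(x_edges)  # bin widths along x
dy = np.diff(y_edges)  # bin widths along y

# Integrate the joint density over the other variable to get each marginal
marginal_x = joint @ dy
marginal_y = dx @ joint

# 1-D density histograms computed directly on the same bin edges
hist_x, _ = np.histogram(x, bins=x_edges, density=True)
hist_y, _ = np.histogram(y, bins=y_edges, density=True)

# Total probability mass of the joint histogram
total = (joint * np.outer(dx, dy)).sum()

print("marginals match:", np.allclose(marginal_x, hist_x), np.allclose(marginal_y, hist_y))
print("total probability:", total)
```

Because the marginals are computed on the same bin edges as the joint histogram, they agree up to floating-point error, which is what the histograms in the margins of the figure display.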
  1. Create a 2-D histogram with violin-plot marginals by calling the function density_heatmap with the arguments marginal_x and marginal_y set to 'violin'.

# 2-D histogram of (X, Y), with a violin plot of each variable in the margins
fig = px.density_heatmap(df, x="X", y="Y",
                         marginal_x="violin", marginal_y="violin",
                         nbinsx=25,
                         nbinsy=25,
                         histnorm='probability density',
                         height=500, width=800,
                         title='Sample from a Bivariate Normal Distribution')
fig.show()
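A violin plot draws a smoothed estimate of a variable's density, usually a kernel density estimate. The sketch below builds such an estimate for one variable with `scipy.stats.gaussian_kde`, to show roughly what the outline of a violin represents. It is an illustration under assumptions: the data are regenerated with numpy's random generator (seed 0), and scipy's default bandwidth rule may differ from the one plotly applies, so the curve will not necessarily coincide with the plotted violin.

```python
import numpy as np
from scipy.stats import gaussian_kde

# Stand-in for the recipe's Y column (assumption: numpy generator with seed 0)
rng = np.random.default_rng(0)
sample = rng.multivariate_normal([1.0, 3.0], [[1.0, 0.3], [0.3, 0.5]], size=200)
y = sample[:, 1]

# Gaussian kernel density estimate with scipy's default bandwidth
kde = gaussian_kde(y)

# Evaluate the estimate on a grid extending past the observed range
pad = 3 * y.std()
grid = np.linspace(y.min() - pad, y.max() + pad, 400)
density = kde(grid)

# Trapezoidal rule: the estimated density should integrate to about 1
step = grid[1] - grid[0]
area = np.sum((density[:-1] + density[1:]) / 2) * step

# Location of the highest point of the estimated density
peak = grid[np.argmax(density)]

print(f"area under the estimated density: {area:.4f}")
print(f"estimated mode of Y: {peak:.3f}")
```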
  1. Create a scatter plot showing different samples with violin-plot marginals by calling the function scatter with the argument color to differentiate the samples, and with marginal_x and marginal_y set to 'violin'.

# Iris data set bundled with plotly; each species is drawn in its own color
df = px.data.iris()
fig = px.scatter(df, x="sepal_length", y="sepal_width", color="species",
                 height=500, width=800,
                 marginal_x="violin", marginal_y="violin",
                 title="Iris Data: Sepal Width vs Length by Species")
fig.show()