3. Making a scatter with a trend line

Combining a scatter plot with a trend line creates a powerful visual tool for analyzing relationships between variables while also illustrating overall trends in the data.

In a scatter plot, individual data points are plotted based on their values for two variables, providing insight into the correlation or distribution of those variables. Adding a trend line summarizes the overall direction of the data, showing whether the correlation between the variables is positive, negative, or absent. A trend line also makes outliers easier to spot: points that sit far from the line stand out in a way that may not be obvious from the scatter plot alone.
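The sketch below illustrates both ideas. It fits a straight line with NumPy's `np.polyfit` and reads the direction of the relationship from the sign of the slope. It then flags points whose residual is more than three standard deviations from the line. The synthetic data, the injected outlier, and the 3-standard-deviation cutoff are assumptions chosen for illustration, not a rule from this recipe.

```python
import numpy as np

# Hypothetical data: a positive linear trend plus noise, with one point pushed far off the line
rng = np.random.default_rng(0)
x = np.linspace(0, 10, 50)
y = 3 * x + rng.normal(scale=1.0, size=x.size)
y[25] += 25  # injected outlier (illustrative assumption)

# Fit y = slope * x + intercept by least squares (degree-1 polynomial)
slope, intercept = np.polyfit(x, y, deg=1)
direction = "positive" if slope > 0 else "negative" if slope < 0 else "none"

# Residuals: vertical distance of each point from the fitted line
residuals = y - (slope * x + intercept)
# Flag points more than 3 residual standard deviations away (an arbitrary cutoff)
outliers = np.flatnonzero(np.abs(residuals) > 3 * residuals.std())

print(f"slope={slope:.2f} ({direction}), outlier indices: {outliers.tolist()}")
```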

🚀 When to use them:

This combination is particularly useful when you want to see both the individual data points and the general pattern they form. For example, in regression analysis, a scatter plot with a trend line can visually depict how well a model fits the data, making it valuable in fields like economics, science, or marketing when examining relationships between variables (e.g., advertising spend vs. sales).

⚠️ Be aware:

However, these charts are less useful when there is little to no relationship between variables, as the trend line may be misleading or not informative. Additionally, with a large amount of data or heavily clustered points, the scatter plot can become crowded, making it difficult to interpret the results or spot individual data points clearly.
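A fitting routine will always return a line, even for unrelated variables. It is therefore worth checking how much of the variation the line explains before relying on it. The sketch below uses `scipy.stats.linregress` on two synthetic data sets, one related and one unrelated, and reports R² and the slope's p-value. The data and the seed are assumptions made for illustration.

```python
import numpy as np
from scipy.stats import linregress

rng = np.random.default_rng(42)
x = rng.uniform(0, 10, size=300)
datasets = {
    "related": 2 * x + rng.normal(scale=2, size=x.size),   # clear linear relationship
    "unrelated": rng.normal(scale=2, size=x.size),         # y ignores x entirely
}

fits = {}
for label, y in datasets.items():
    fit = linregress(x, y)
    fits[label] = fit
    # A line is always returned, so inspect R^2 and the slope's p-value before trusting it
    print(f"{label:>9}: slope={fit.slope:.3f}, R^2={fit.rvalue**2:.3f}, p-value={fit.pvalue:.3g}")
```

Crowded plots are a separate issue. `px.scatter` accepts an `opacity` argument (for example `opacity=0.3`), which makes overlapping points easier to distinguish.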

Getting ready

In addition to plotly, numpy and pandas, make sure the following Python libraries are available in your Python environment:

  • statsmodels

  • scipy

You can install them using the command:

pip install statsmodels scipy
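To confirm the installation, you can check which of the required packages are importable. The sketch below relies only on the standard library's `importlib.metadata` (Python 3.8+). The helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import version, PackageNotFoundError

def installed_versions(packages):
    """Hypothetical helper: map each package to its version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

status = installed_versions(["plotly", "numpy", "pandas", "statsmodels", "scipy"])
for name, ver in status.items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```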

For this recipe we will create two data sets.

  1. Import the Python modules numpy and pandas, and import the norm object from scipy.stats. This object lets us draw random samples from a normal distribution, which we will use to build the data sets for this recipe.

import numpy as np
import pandas as pd
from scipy.stats import norm
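As a quick illustration of how `norm` is used below, calling it with `loc` (mean) and `scale` (standard deviation) returns a "frozen" distribution. Its `rvs` method draws samples from that distribution. The `random_state` argument is not used in the recipe; it is shown here only to make draws reproducible.

```python
from scipy.stats import norm

# norm(loc, scale) returns a frozen normal distribution with mean=loc and std=scale
dist = norm(loc=20, scale=100)
print(dist.mean(), dist.std())   # 20.0 100.0

# rvs draws random samples; the same random_state gives the same draws
a = dist.rvs(size=5, random_state=0)
b = dist.rvs(size=5, random_state=0)
print(a.shape, (a == b).all())   # (5,) True
```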
  1. Create two data sets to be used in this recipe:

  • data1: contains two variables, x and y, with a linear relationship

  • data2: contains two variables, x and y, with a non-linear relationship

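# data1: y depends linearly on x, plus normally distributed noise
n = 200
x = np.linspace(0, 15, n)
epsilon = norm().rvs(n)          # n draws from a standard normal N(0, 1)
sigma = 2                        # standard deviation of the noise
y = 2*x + sigma*epsilon
data1 = pd.DataFrame({'x':x, 'y':y})
# data2: y depends on x**3, plus noise drawn from a normal with mean 20 and std 100
n = 200
x = np.linspace(0, 15, n)
epsilon = norm(loc=20, scale=100).rvs(n)
y = 0.5*x**3 + epsilon - 10
data2 = pd.DataFrame({'x':x, 'y':y})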
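The samples above are unseeded, so the numbers (and the regression summaries later on) change on every run. If you want repeatable results, one option is to wrap the same generation steps in a function that takes a seed. The function name `make_datasets` and its `seed` parameter are hypothetical additions, not part of the recipe.

```python
import numpy as np
import pandas as pd
from scipy.stats import norm

def make_datasets(n=200, seed=None):
    """Hypothetical helper: rebuild data1 (linear) and data2 (cubic), optionally seeded."""
    rng = np.random.default_rng(seed)
    x = np.linspace(0, 15, n)
    # data1: y = 2x + noise with standard deviation 2
    y1 = 2 * x + 2 * norm().rvs(n, random_state=rng)
    # data2: y = 0.5x^3 + noise from N(20, 100^2), shifted by -10
    y2 = 0.5 * x**3 + norm(loc=20, scale=100).rvs(n, random_state=rng) - 10
    return pd.DataFrame({"x": x, "y": y1}), pd.DataFrame({"x": x, "y": y2})

data1, data2 = make_datasets(seed=123)
data1_again, _ = make_datasets(seed=123)
print(data1.shape, data2.shape, data1.equals(data1_again))
```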

How to do it

  1. Import the plotly.express module as px.

import plotly.express as px
  1. Make a simple scatter plot of the points in the data1 data set using the function px.scatter.

df = data1
fig = px.scatter(df, x='x', y='y',
                 height=600, width=800,
                 title='Just a simple scatter')
fig.show()

We can observe that there is a linear relationship between the variables!
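To back the visual impression with a number, you can compute Pearson's correlation coefficient directly from the DataFrame. The sketch below regenerates data1 with a fixed seed so it runs on its own; the seed is an assumption. A value close to 1 indicates a strong positive linear relationship.

```python
import numpy as np
import pandas as pd
from scipy.stats import norm

# Rebuild data1 as in the recipe, seeded here only so the example is reproducible
n = 200
x = np.linspace(0, 15, n)
y = 2 * x + 2 * norm().rvs(n, random_state=1)
data1 = pd.DataFrame({"x": x, "y": y})

# Pearson correlation between the two columns (pandas' default method)
r = data1["x"].corr(data1["y"])
print(f"Pearson r = {r:.3f}")
```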

Linear Trend

  1. Add a line that captures the linear relationship in the data. To do this, add the argument trendline and pass the string "ols". This draws the line fitted by Ordinary Least Squares (OLS) regression. Plotly Express relies on statsmodels to compute this fit, which is why statsmodels must be installed.

fig = px.scatter(df, x='x', y='y',
                 trendline="ols",
                 height=600, width=800,
                 title='Scatter with OLS trend line')
fig.show()
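With a single predictor, the OLS line has a closed form. The slope is b1 = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)², and the intercept is b0 = ȳ − b1·x̄. The sketch below computes both by hand on data1-style data (seeded, an assumption) and checks them against `np.polyfit`. It is meant to show what the drawn line represents, not how Plotly computes it internally.

```python
import numpy as np
from scipy.stats import norm

n = 200
x = np.linspace(0, 15, n)
y = 2 * x + 2 * norm().rvs(n, random_state=7)  # data1-style data, seeded for reproducibility

# Closed-form OLS estimates for y = b0 + b1 * x
x_c, y_c = x - x.mean(), y - y.mean()
b1 = (x_c * y_c).sum() / (x_c ** 2).sum()
b0 = y.mean() - b1 * x.mean()

# np.polyfit with deg=1 solves the same least-squares problem (highest degree first)
slope, intercept = np.polyfit(x, y, deg=1)
print(f"by hand: b0={b0:.4f}, b1={b1:.4f}; polyfit: b0={intercept:.4f}, b1={slope:.4f}")
```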
  1. Change the color of the trend line by using trendline_color_override.

fig = px.scatter(df, x='x', y='y',
                 trendline_color_override="red",
                 trendline="ols", 
                 height=600, width=800,
                 title='Scatter with OLS trend line')
fig.show()
  1. Retrieve the results of the OLS fit by using the Plotly Express function get_trendline_results, passing your figure object.

results_table = px.get_trendline_results(fig)
results_table
px_fit_results
0 <statsmodels.regression.linear_model.Regressio...

Let’s check what type of object this returns:

type(results_table)
pandas.core.frame.DataFrame

It is a pandas DataFrame, with one row per trend line in the figure (here, just one).

  1. Extract the object containing the results from the DataFrame. This is a statsmodels.regression.linear_model.RegressionResultsWrapper object.

results = results_table['px_fit_results'].iloc[0]
results
<statsmodels.regression.linear_model.RegressionResultsWrapper at 0x111c2ba50>
type(results)
statsmodels.regression.linear_model.RegressionResultsWrapper
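  1. Get the full details of the regression by calling the summary method of the results object. This method returns a statsmodels Summary object, which Jupyter renders as the tables below (print(results.summary()) gives a plain-text version).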

results.summary()
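                            OLS Regression Results
==============================================================================
Dep. Variable:                     y   R-squared:                       0.940
Model:                           OLS   Adj. R-squared:                  0.940
Method:                Least Squares   F-statistic:                     3095.
Date:               Fri, 21 Feb 2025   Prob (F-statistic):          7.88e-123
Time:                       13:59:44   Log-Likelihood:                -438.56
No. Observations:                200   AIC:                             881.1
Df Residuals:                    198   BIC:                             887.7
Df Model:                          1
Covariance Type:           nonrobust
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const          0.4205      0.307      1.370      0.172      -0.185       1.026
x1             1.9697      0.035     55.632      0.000       1.900       2.039
==============================================================================
Omnibus:                       2.853   Durbin-Watson:                   1.777
Prob(Omnibus):                 0.240   Jarque-Bera (JB):                2.896
Skew:                         -0.284   Prob(JB):                        0.235
Kurtosis:                      2.842   Cond. No.                         17.4
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.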

Note that there is a similar method named summary2. It also returns a summary object, but it stores its tables as pandas DataFrames (in its tables attribute). However, statsmodels describes it as experimental, so it should be used with caution.

Non-Linear Trend

  1. Make a scatter plot of the points in the data2 data set, including the OLS regression line for contrast. The plot makes it clear that the data do not follow a straight line.

df = data2
fig = px.scatter(df, x='x', y='y',
                 trendline="ols", 
                 trendline_color_override="red",
                 height=600, width=800,
                 title='Scatter with OLS trend line')
fig.show()
  1. Import statsmodels.formula.api as smf. Its formula interface lets us specify a model that is non-linear in x for the data in data2.

import statsmodels.formula.api as smf
  1. Fit the model by ordinary least squares using smf.ols, passing:

  • formula: a string that specifies the curve we want to fit. Here 'y ~ I(x**3)' fits y = b0 + b1·x³, an intercept plus a cubic term, which is the form used to generate data2. The model is non-linear in x but still linear in its coefficients, which is why OLS can fit it.

  • data: the DataFrame with the data set to be fitted

model = smf.ols(formula='y ~ I(x**3)', data=df).fit()
predicted = model.predict(df)   # evaluate the fitted curve at each x in df
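The sketch below fits 'y ~ I(x**3)' alongside a full cubic polynomial, 'y ~ x + I(x**2) + I(x**3)', to show what the extra terms would add. It then compares the two models by R² and AIC (lower AIC is better). The data are regenerated data2-style values with a seed, which is an assumption. Because the models are nested, R² cannot decrease when terms are added. AIC penalizes the extra parameters, so it is the fairer comparison.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from scipy.stats import norm

n = 200
x = np.linspace(0, 15, n)
y = 0.5 * x**3 + norm(loc=20, scale=100).rvs(n, random_state=11) - 10  # seeded (assumption)
df = pd.DataFrame({"x": x, "y": y})

# I(...) makes the formula evaluate x**3 arithmetically; without it, ** is formula syntax
cubic_only = smf.ols("y ~ I(x**3)", data=df).fit()
full_cubic = smf.ols("y ~ x + I(x**2) + I(x**3)", data=df).fit()

for name, res in [("y ~ I(x**3)", cubic_only), ("full cubic", full_cubic)]:
    print(f"{name:>12}: params={len(res.params)}, R^2={res.rsquared:.4f}, AIC={res.aic:.1f}")
```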
  1. Plot the scatter together with the curve given by the fitted polynomial evaluated at each value of x.

fig = px.scatter(df, x='x', y='y',
                 height=600, width=800,
                 title='Scatter + Fitted Polynomial')
fig.add_scatter(x=df.x, y=predicted, name="Fitted Polynomial")
fig.show()
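Because the model was fitted through a formula, `predict` accepts any DataFrame that has an `x` column. The curve can therefore be evaluated on a dense grid of new x values, not only at the observed points. In this sketch, the grid of 300 points and the seed are arbitrary choices. The last line checks that the prediction is simply b0 + b1·x³.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from scipy.stats import norm

n = 200
x = np.linspace(0, 15, n)
y = 0.5 * x**3 + norm(loc=20, scale=100).rvs(n, random_state=11) - 10  # seeded (assumption)
df = pd.DataFrame({"x": x, "y": y})
model = smf.ols("y ~ I(x**3)", data=df).fit()

# New x values; the column name must match the variable used in the formula
grid = pd.DataFrame({"x": np.linspace(0, 15, 300)})
curve = model.predict(grid)

# The prediction is just b0 + b1 * x**3 evaluated on the grid
b0, b1 = model.params.to_numpy()
print(np.allclose(curve, b0 + b1 * grid["x"] ** 3))   # True
```

To plot it, pass `grid["x"]` and `curve` to `fig.add_scatter` in place of `df.x` and `predicted`.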
  1. Get the full details of the model by using the method summary.

model.summary()
                            OLS Regression Results
==============================================================================
Dep. Variable:                     y   R-squared:                       0.961
Model:                           OLS   Adj. R-squared:                  0.961
Method:                Least Squares   F-statistic:                     4893.
Date:               Fri, 21 Feb 2025   Prob (F-statistic):          1.44e-141
Time:                       13:59:44   Log-Likelihood:                -1198.3
No. Observations:                200   AIC:                             2401.
Df Residuals:                    198   BIC:                             2407.
Df Model:                          1
Covariance Type:           nonrobust
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      2.9814      9.163      0.325      0.745     -15.089      21.052
I(x ** 3)      0.4993      0.007     69.949      0.000       0.485       0.513
==============================================================================
Omnibus:                       3.426   Durbin-Watson:                   1.854
Prob(Omnibus):                 0.180   Jarque-Bera (JB):                3.041
Skew:                          0.239   Prob(JB):                        0.219
Kurtosis:                      3.370   Cond. No.                     1.71e+03
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
[2] The condition number is large, 1.71e+03. This might indicate that there are
strong multicollinearity or other numerical problems.
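With a single regressor, the condition-number warning here likely reflects scale more than genuine multicollinearity. The x³ column runs from 0 to 3375, far from the scale of the intercept's column of ones. One common remedy is to divide x by its maximum (15) before cubing, sketched below under the assumption that rescaling is acceptable for your analysis. The fitted curve is unchanged, and only the cubic coefficient is rescaled (multiplied by 15³). The data are regenerated with a seed for illustration.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from scipy.stats import norm

n = 200
x = np.linspace(0, 15, n)
y = 0.5 * x**3 + norm(loc=20, scale=100).rvs(n, random_state=11) - 10  # seeded (assumption)
df = pd.DataFrame({"x": x, "y": y})

raw = smf.ols("y ~ I(x**3)", data=df).fit()
# Rescale x to [0, 1] before cubing; I(...) evaluates the expression inside the formula
scaled = smf.ols("y ~ I((x / 15)**3)", data=df).fit()

print(f"condition number: raw={raw.condition_number:.3g}, scaled={scaled.condition_number:.3g}")
# Same fitted curve; the cubic coefficient is rescaled by 15**3
print(np.allclose(raw.fittedvalues, scaled.fittedvalues))
print(np.isclose(scaled.params.iloc[1], raw.params.iloc[1] * 15**3))
```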