# 3. Making a scatter with a trend line
Combining a scatter plot with a trend line creates a powerful visual tool for analyzing relationships between variables while also illustrating overall trends in the data.
In a scatter plot, individual data points are plotted based on their values for two variables, providing insight into the correlation or distribution of those variables. When a trend line is added, it helps to summarize the overall direction of the data, showing whether there is a positive, negative, or no correlation between the variables. Trend lines can also help identify outliers or anomalies in the data that might not be immediately noticeable from just the scatter plot.
🚀 When to use them:
This combination is particularly useful when you want to see both the individual data points and the general pattern they form. For example, in regression analysis, a scatter plot with a trend line can visually depict how well a model fits the data, making it valuable in fields like economics, science, or marketing when examining relationships between variables (e.g., advertising spend vs. sales).
⚠️ Be aware:
However, these charts are less useful when there is little to no relationship between variables, as the trend line may be misleading or not informative. Additionally, with a large amount of data or heavily clustered points, the scatter plot can become crowded, making it difficult to interpret the results or spot individual data points clearly.
## Getting ready
In addition to `plotly`, `numpy`, and `pandas`, make sure the following Python libraries are available in your Python environment:

- `statsmodels`
- `scipy`

You can install them using the command:

```
pip install statsmodels scipy
```
For this recipe, we will create two data sets.

Import the Python modules `numpy` and `pandas`, and import the `norm` object from `scipy.stats`. This object allows us to generate random samples from a normal distribution, which we will use to build the data sets for this recipe.
```python
import numpy as np
import pandas as pd
from scipy.stats import norm
```
Create two data sets to be used in this recipe:

- `data1`: contains two variables, `x` and `y`, with a linear relationship
- `data2`: contains two variables, `x` and `y`, with a non-linear relationship
```python
n = 200
x = np.linspace(0, 15, n)
epsilon = norm().rvs(n)
sigma = 2
y = 2*x + sigma*epsilon
data1 = pd.DataFrame({'x': x, 'y': y})
```

```python
n = 200
x = np.linspace(0, 15, n)
epsilon = norm(loc=20, scale=100).rvs(n)
y = 0.5*x**3 + epsilon - 10
data2 = pd.DataFrame({'x': x, 'y': y})
```
## How to do it

Import the `plotly.express` module as `px`:
```python
import plotly.express as px
```
Make a simple scatter plot to illustrate the points in the `data1` data set using the `scatter` function:

```python
df = data1
fig = px.scatter(df, x='x', y='y',
                 height=600, width=800,
                 title='Just a simple scatter')
fig.show()
```
We can observe that there is a linear relationship between the variables!
## Linear Trend
Add a line that captures the linear relationship in the data. To do this, simply add the `trendline` argument and pass the string `"ols"`. This will draw the line determined by the Ordinary Least Squares (OLS) regression method.
```python
fig = px.scatter(df, x='x', y='y',
                 trendline="ols",
                 height=600, width=800,
                 title='Scatter with OLS trend line')
fig.show()
```
Change the color of the trend line by using `trendline_color_override`:

```python
fig = px.scatter(df, x='x', y='y',
                 trendline_color_override="red",
                 trendline="ols",
                 height=600, width=800,
                 title='Scatter with OLS trend line')
fig.show()
```
Retrieve the results of the OLS fit by using the `plotly` function `get_trendline_results` and passing your figure object:

```python
results_table = px.get_trendline_results(fig)
results_table
```

|   | px_fit_results |
|---|---|
| 0 | <statsmodels.regression.linear_model.Regressio... |
Let's check what type of object this returns:

```python
type(results_table)
```

```
pandas.core.frame.DataFrame
```

It is a pandas `DataFrame`.
Extract the object containing the results from the `DataFrame`. This is a `statsmodels.regression.linear_model.RegressionResultsWrapper` object:

```python
results = results_table['px_fit_results'][0]
results
```

```
<statsmodels.regression.linear_model.RegressionResultsWrapper at 0x111c2ba50>
```

```python
type(results)
```

```
statsmodels.regression.linear_model.RegressionResultsWrapper
```
Get the full details on the regression by using the `summary` method from the `results` object. This method returns a `Summary` object containing the regression tables:

```python
results.summary()
```
| Dep. Variable: | y | R-squared: | 0.940 |
|---|---|---|---|
| Model: | OLS | Adj. R-squared: | 0.940 |
| Method: | Least Squares | F-statistic: | 3095. |
| Date: | Fri, 21 Feb 2025 | Prob (F-statistic): | 7.88e-123 |
| Time: | 13:59:44 | Log-Likelihood: | -438.56 |
| No. Observations: | 200 | AIC: | 881.1 |
| Df Residuals: | 198 | BIC: | 887.7 |
| Df Model: | 1 | | |
| Covariance Type: | nonrobust | | |

| | coef | std err | t | P>\|t\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| const | 0.4205 | 0.307 | 1.370 | 0.172 | -0.185 | 1.026 |
| x1 | 1.9697 | 0.035 | 55.632 | 0.000 | 1.900 | 2.039 |

| Omnibus: | 2.853 | Durbin-Watson: | 1.777 |
|---|---|---|---|
| Prob(Omnibus): | 0.240 | Jarque-Bera (JB): | 2.896 |
| Skew: | -0.284 | Prob(JB): | 0.235 |
| Kurtosis: | 2.842 | Cond. No. | 17.4 |
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Note that there is a similar method named `summary2`, which also returns a summary of the fit. However, it is an experimental version and as such should be used with caution.
## Non-Linear Trend
Make a scatter plot to illustrate the points in the `data2` data set. Include the OLS regression line to contrast it against the data. It is clear that the data does not show a linear relationship:
```python
df = data2
fig = px.scatter(df, x='x', y='y',
                 trendline="ols",
                 trendline_color_override="red",
                 height=600, width=800,
                 title='Scatter with OLS trend line')
fig.show()
```
Import `statsmodels.formula.api` as `smf`. This will help us set up a non-linear model based on the data in `data2`:

```python
import statsmodels.formula.api as smf
```
Fit a non-linear OLS model to the data by using `smf.ols` and passing:

- `formula`: a string which specifies the non-linear curve that we want to fit. In this case, we are going to fit a cubic polynomial
- `data`: the `DataFrame` with the data set to be fitted
```python
model = smf.ols(formula='y ~ I(x**3)', data=df).fit()
predicted = model.predict(df.x)
```
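As a quick sanity check, the fitted coefficient of the cubic term should land close to the true value of 0.5 used to generate `data2`. The sketch below is self-contained, regenerating similar (but not identical) data rather than reusing the recipe's `df`:

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Regenerate data similar to data2: y = 0.5*x^3 + noise - 10
rng = np.random.default_rng(0)
x = np.linspace(0, 15, 200)
df = pd.DataFrame({'x': x,
                   'y': 0.5 * x**3 + rng.normal(loc=20, scale=100, size=200) - 10})

# I(...) tells the formula parser to evaluate x**3 as a single term
model = smf.ols(formula='y ~ I(x**3)', data=df).fit()
cubic_coef = model.params.iloc[1]
print(f"estimated cubic coefficient: {cubic_coef:.4f}")
```

Without the `I(...)` wrapper, `**` would be interpreted by the formula language itself rather than as a numerical power, so the wrapper is required for polynomial terms.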
Plot the scatter together with the curve given by the fitted polynomial evaluated on the `x` variable:

```python
fig = px.scatter(df, x='x', y='y',
                 height=600, width=800,
                 title='Scatter + Fitted Polynomial')
fig.add_scatter(x=df.x, y=predicted, name="Fitted Polynomial")
fig.show()
```
Get the full details of the model by using the `summary` method:

```python
model.summary()
```
| Dep. Variable: | y | R-squared: | 0.961 |
|---|---|---|---|
| Model: | OLS | Adj. R-squared: | 0.961 |
| Method: | Least Squares | F-statistic: | 4893. |
| Date: | Fri, 21 Feb 2025 | Prob (F-statistic): | 1.44e-141 |
| Time: | 13:59:44 | Log-Likelihood: | -1198.3 |
| No. Observations: | 200 | AIC: | 2401. |
| Df Residuals: | 198 | BIC: | 2407. |
| Df Model: | 1 | | |
| Covariance Type: | nonrobust | | |

| | coef | std err | t | P>\|t\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| Intercept | 2.9814 | 9.163 | 0.325 | 0.745 | -15.089 | 21.052 |
| I(x ** 3) | 0.4993 | 0.007 | 69.949 | 0.000 | 0.485 | 0.513 |

| Omnibus: | 3.426 | Durbin-Watson: | 1.854 |
|---|---|---|---|
| Prob(Omnibus): | 0.180 | Jarque-Bera (JB): | 3.041 |
| Skew: | 0.239 | Prob(JB): | 0.219 |
| Kurtosis: | 3.370 | Cond. No. | 1.71e+03 |
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
[2] The condition number is large, 1.71e+03. This might indicate that there are strong multicollinearity or other numerical problems.