Simple Linear Regression in Python with Statsmodels
Simple linear regression is one of the most fundamental techniques in statistics and data analysis.
It allows you to understand and quantify the relationship between two variables by modeling how changes in an independent variable influence a dependent variable.
Whether you’re working in finance, marketing, engineering, or data science, mastering simple linear regression is essential for making informed decisions based on data.
In this guide, we’ll walk you through how to perform simple linear regression using Python’s statsmodels library, a powerful tool that simplifies statistical modeling and provides detailed insights into your model’s performance.
First, you’ll learn how to set up your environment by installing the necessary libraries and importing them into your Python script.
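You can install statsmodels, together with matplotlib for plotting, with the command pip install statsmodels matplotlib (numpy and pandas are installed automatically as statsmodels dependencies).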
After installation, import the essential libraries: numpy for numerical operations, pandas for data manipulation, statsmodels.api for statistical modeling, and matplotlib.pyplot for visualization.
Next, you’ll need a dataset to analyze.
For demonstration purposes, we’ll generate a simple synthetic dataset that simulates the relationship between advertising spending and sales revenue.
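# Import the libraries used throughout this guide
import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
# Creating a sample dataset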
np.random.seed(425)
# Generate 50 random advertising budget values (independent variable X)
X = np.random.randint(10, 100, 50)
# Generate sales revenue with some noise (dependent variable Y)
Y = 5 + 1.5 * X + np.random.normal(0, 10, 50)
# Convert to a Pandas DataFrame
df = pd.DataFrame({'Advertising Budget': X, 'Sales Revenue': Y})
# Display the first five rows
print(df.head())
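Setting a random seed makes the example reproducible. The code then generates 50 random advertising budget values (integers from 10 to 99) as the independent variable and models sales revenue as 5 + 1.5 * X plus normally distributed noise (mean 0, standard deviation 10) to simulate real-world data.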
Convert this data into a pandas DataFrame for easier manipulation and display the first few rows to verify the dataset.
Preparing the data for statsmodels involves adding a constant term to the independent variable to account for the y-intercept in the regression equation.
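Unlike scikit-learn’s LinearRegression, which fits an intercept by default, sm.OLS does not add one automatically, so you need to add it yourself with sm.add_constant(). (The formula interface, statsmodels.formula.api.ols, does include an intercept by default.)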
# Add a constant (intercept) to the independent variable
X = sm.add_constant(df['Advertising Budget']) # Adds a column of ones
# Define the dependent variable
Y = df['Sales Revenue']
Define your dependent variable as the sales revenue column and include the constant in your independent variables.
Fitting the simple linear regression model is straightforward with statsmodels’ OLS (ordinary least squares) class, sm.OLS.
# Create the OLS model
model = sm.OLS(Y, X)
# Fit the model
results = model.fit()
Create the model by passing the dependent variable first and the independent variables second, as in sm.OLS(Y, X), then fit it with the model.fit() method.
This process estimates the slope and intercept of the best-fit line by minimizing the sum of squared residuals.
The fitted results object exposes statistical measures such as R-squared, p-values, and confidence intervals, which help you evaluate how well the model fits the data and whether the estimated coefficients are statistically significant.
Once your model is fitted, generate a detailed summary of the regression results using the .summary() method.
print(results.summary())
This report provides insights into how well your independent variable explains the variation in the dependent variable, the significance of the predictors, and other diagnostic metrics.
Analyzing this summary helps you assess the quality and reliability of your regression model.
Visualizing your regression results enhances understanding and interpretation.
# Scatter plot of actual data
plt.scatter(df['Advertising Budget'], df['Sales Revenue'], color='blue', label="Actual Data")
# Plot the regression line
plt.plot(df['Advertising Budget'], results.predict(X), color='red', linewidth=2, label="Regression Line")
# Labels and title
plt.xlabel("Advertising Budget (in $1000s)")
plt.ylabel("Sales Revenue (in $1000s)")
plt.title("Advertising Budget vs. Sales Revenue (with Regression Line)")
plt.legend()
plt.grid()
plt.show()
Create a scatter plot of the actual data points (advertising budget versus sales revenue) and overlay the regression line predicted by your model.
This visual representation allows you to quickly see how well the model fits the data and identify any patterns or anomalies.
A good fit shows the data points scattered closely and evenly around the regression line, with no systematic curve or funnel shape, which indicates a strong linear relationship.
Performing simple linear regression using statsmodels in Python is an efficient way to explore relationships between variables and derive meaningful insights from your data.
This technique is widely applicable across industries, including finance, marketing, engineering, and research.
It offers a transparent and comprehensive way to analyze the influence of one variable on another, backed by robust statistical metrics.
Whether you’re building predictive models or conducting exploratory data analysis, mastering simple linear regression is an invaluable skill that opens doors to advanced statistical modeling and data-driven decision-making.