
Understanding Linear Regression in Machine Learning

4 min read
Vadim Nicolai
Senior Software Engineer at Vitrifi

Introduction

Linear regression is a fundamental algorithm in supervised machine learning, widely used for predicting continuous outcomes. It models the relationship between a dependent variable and one or more independent variables by fitting a linear equation to observed data. This article delves into the components of linear regression, explaining how inputs, parameters, and the cost function work together to create a predictive model.

The Linear Regression Model

Model Equation

At the heart of linear regression is the model equation:

f_{w,b}(x) = w \cdot x + b
  • x: The input feature or independent variable.
  • w: The weight or coefficient, representing the slope of the line.
  • b: The bias or intercept term, indicating where the line crosses the y-axis.
  • f_{w,b}(x): The predicted output for a given input x.

This equation represents a straight line in two-dimensional space, where the goal is to find the optimal values of w and b that minimize the difference between the predicted outputs and the actual outputs.
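
To make the equation concrete, here is a minimal sketch in Python (using NumPy); the values of w and b below are illustrative, not learned from data:

```python
import numpy as np

def predict(x, w, b):
    """Compute the linear model f_{w,b}(x) = w * x + b."""
    return w * x + b

# Illustrative parameter values and a few input features.
w, b = 2.0, 1.0
x = np.array([1.0, 2.0, 3.0])

print(predict(x, w, b))  # [3. 5. 7.]
```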

Inputs or Features

The inputs, also known as features, are the data fed into the model to make predictions. In the context of the linear regression model:

  • x is the input feature that the model uses to predict an output y.
  • The model expects x as input to compute the predicted value f_{w,b}(x).

Parameters

The parameters of the model are the variables that the learning algorithm adjusts during training:

  • w (weight): Determines how much the input x influences the output.
  • b (bias): Allows the model to shift the line up or down to better fit the data.

These parameters are not inputs to the model; instead, they are learned from the data during the training process to minimize the prediction error.

Training the Model

Cost Function

To evaluate how well the model is performing, we use a cost function J(w, b), often defined as the mean squared error (MSE) between the predicted outputs and the actual outputs:

J(w, b) = \frac{1}{2m} \sum_{i=1}^{m} \left( f_{w,b}\left( x^{(i)} \right) - y^{(i)} \right)^2
  • m: The number of training examples.
  • x^{(i)}: The i-th input feature.
  • y^{(i)}: The actual output corresponding to x^{(i)}.

The cost function quantifies the error of the model; the objective is to find the parameters w and b that minimize J(w, b).
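
As a sketch, the cost function can be written directly from the formula above; x_train and y_train are placeholder names for your training data:

```python
import numpy as np

def compute_cost(x, y, w, b):
    """Mean squared error cost J(w, b) with the 1/(2m) convention."""
    m = x.shape[0]
    predictions = w * x + b              # f_{w,b}(x^{(i)}) for every example
    squared_errors = (predictions - y) ** 2
    return squared_errors.sum() / (2 * m)

# Illustrative data: y = 2x + 1 exactly, so the cost at w=2, b=1 is 0.
x_train = np.array([1.0, 2.0, 3.0])
y_train = np.array([3.0, 5.0, 7.0])

print(compute_cost(x_train, y_train, w=2.0, b=1.0))  # 0.0
print(compute_cost(x_train, y_train, w=1.5, b=0.0))  # larger error
```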

Gradient Descent

To minimize the cost function, we use the gradient descent algorithm, which iteratively updates the parameters in the direction of steepest descent:

Update rule for w:

w \leftarrow w - \alpha \frac{\partial J(w, b)}{\partial w}

Update rule for b:

b \leftarrow b - \alpha \frac{\partial J(w, b)}{\partial b}
  • α: The learning rate, controlling the size of the steps taken to reach the minimum.

By updating w and b using the gradients of the cost function, the model progressively improves its predictions.
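
Below is a minimal gradient descent sketch for this one-feature model. The gradient expressions follow from differentiating the cost above, and the learning rate and iteration count are illustrative choices, not tuned values:

```python
import numpy as np

def gradient_descent(x, y, w=0.0, b=0.0, alpha=0.01, iterations=1000):
    """Iteratively update w and b using the gradients of J(w, b)."""
    m = x.shape[0]
    for _ in range(iterations):
        error = (w * x + b) - y          # f_{w,b}(x^{(i)}) - y^{(i)}
        dj_dw = (error * x).sum() / m    # partial derivative of J w.r.t. w
        dj_db = error.sum() / m          # partial derivative of J w.r.t. b
        w -= alpha * dj_dw
        b -= alpha * dj_db
    return w, b

x_train = np.array([1.0, 2.0, 3.0])
y_train = np.array([3.0, 5.0, 7.0])

w, b = gradient_descent(x_train, y_train, alpha=0.1, iterations=5000)
print(w, b)  # approaches w ≈ 2, b ≈ 1
```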

Model Evaluation

Interpreting the Cost Function Value

  • When J(w, b) is very close to zero, it indicates that the model's predictions are very close to the actual outputs in the training data.
  • A higher value of J(w, b) implies a larger error between the predicted and actual values, suggesting that the model may need more training or a different approach.

Evaluating on Test Data

After training, it's essential to evaluate the model's performance on a separate test dataset to ensure it generalizes well to unseen data. Common metrics include:

  • Mean Squared Error (MSE):

    \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( f_{w,b}\left( x_{\text{test}}^{(i)} \right) - y_{\text{test}}^{(i)} \right)^2
  • Coefficient of Determination (R² Score):

    R^2 = 1 - \frac{\sum_{i=1}^{n} \left( y_{\text{test}}^{(i)} - f_{w,b}\left( x_{\text{test}}^{(i)} \right) \right)^2}{\sum_{i=1}^{n} \left( y_{\text{test}}^{(i)} - \bar{y}_{\text{test}} \right)^2}
    • \bar{y}_{\text{test}}: The mean of the actual test outputs.

An R² score close to 1 indicates that the model explains a large portion of the variance in the data.
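
Here is a sketch of evaluating a trained model on held-out data; x_test and y_test are placeholder arrays, and the same metrics are also available as mean_squared_error and r2_score in scikit-learn if you prefer a library implementation:

```python
import numpy as np

def evaluate(x_test, y_test, w, b):
    """Compute test-set MSE and the R² score for the linear model."""
    predictions = w * x_test + b
    mse = np.mean((predictions - y_test) ** 2)
    ss_res = np.sum((y_test - predictions) ** 2)        # residual sum of squares
    ss_tot = np.sum((y_test - y_test.mean()) ** 2)      # total sum of squares
    r2 = 1 - ss_res / ss_tot
    return mse, r2

# Illustrative held-out data that roughly follows y = 2x + 1.
x_test = np.array([4.0, 5.0, 6.0])
y_test = np.array([9.2, 10.9, 13.1])

mse, r2 = evaluate(x_test, y_test, w=2.0, b=1.0)
print(f"MSE: {mse:.3f}, R²: {r2:.3f}")
```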

Conclusion

Linear regression serves as a foundational tool in machine learning for understanding and predicting relationships between variables. By mastering the components of linear regression—such as the model equation, parameters, cost function, and optimization algorithm—you can build models that effectively predict continuous outcomes. This understanding also paves the way for exploring more complex models and algorithms in the field of machine learning.

For further reading on linear regression and its applications, consider exploring topics like regularization techniques, multivariate regression, and the assumptions underlying linear models.