
Decision Tree Regression

[Figure: decision tree example]

A decision tree is a supervised learning model that predicts a target value by learning simple decision rules inferred from the data features. Internal nodes represent decisions based on feature values, and leaf nodes hold the prediction: a class label in classification, or a numeric value (typically the mean of the training targets in that leaf) in regression.
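
As a toy illustration of what such decision rules look like (the feature names, thresholds, and leaf values below are made up for the example), a fitted tree is essentially a set of nested feature-versus-threshold tests ending in constant predictions:

def predict_price(size_m2, age_years):
    # Root node: test one feature against a threshold.
    if size_m2 <= 80:
        # Internal node: test another feature.
        if age_years <= 20:
            return 220.0   # leaf: predicted price in $1000s
        return 150.0       # leaf
    return 330.0           # leaf

print(predict_price(70, 10))  # routes left, then left -> 220.0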

Building a Decision Tree

The process involves recursively splitting the dataset on feature values so that the resulting subsets are as homogeneous as possible; for regression, homogeneity means that the target values within each subset vary as little as possible:

  1. In regression trees, the quality of a split is evaluated using metrics like Mean Squared Error (MSE) or Mean Absolute Error (MAE): $MSE = \frac{1}{n} \sum_{i=1}^n (y_i - \overline{y})^2$, where $\overline{y}$ is the mean target value in the node. The error of a candidate split is the sample-weighted average of the errors of the two resulting subsets (a small worked example follows this list).
  2. Once the best feature is selected based on the chosen metric, the dataset is partitioned into subsets. This process is recursively applied to each subset, creating branches of the tree.
  3. The recursion continues until a stopping condition is met, such as:
    • No feature offers a useful split (for example, all target values in the node are already identical).
    • The node contains fewer samples than the minimum required to split (min_samples_split in the code below).
    • The tree reaches a predefined maximum depth.
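
To make the weighted-error criterion concrete, here is a small worked example (the numbers are made up) computing the weighted MSE of one candidate split:

import numpy as np

# Target values that land in the left and right subsets after a candidate split.
y_left = np.array([10.0, 12.0])
y_right = np.array([30.0, 34.0])

def mse(y):
    return np.mean((y - np.mean(y)) ** 2)

# Weight each subset's MSE by its share of the samples.
n = len(y_left) + len(y_right)
weighted = len(y_left) / n * mse(y_left) + len(y_right) / n * mse(y_right)
print(weighted)  # 2.5 -- the lower this value, the better the split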

From Scratch

import numpy as np


def split_dataset(X, y, feature_index, threshold):
    # Samples with feature value <= threshold go to the left subset, the rest to the right.
    left_mask = X[:, feature_index] <= threshold
    right_mask = X[:, feature_index] > threshold
    return X[left_mask], y[left_mask], X[right_mask], y[right_mask]


def mse(y):
    # Mean squared error of predicting the mean of y for every sample.
    return np.mean((y - np.mean(y)) ** 2)


def weighted_mse(y_left, y_right):
    # Error of a split: each subset's MSE weighted by its share of the samples.
    total_len = len(y_left) + len(y_right)
    return (len(y_left) / total_len) * mse(y_left) + (len(y_right) / total_len) * mse(y_right)


class DecisionTreeRegressor:
    def __init__(self, max_depth=None, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    def fit(self, X, y):
        self.tree = self._build_tree(X, y)

    def _build_tree(self, X, y, depth=0):
        num_samples, num_features = X.shape
        # Stop and create a leaf if the depth limit is reached, the node is too small,
        # or all target values are identical.
        if (self.max_depth is not None and depth >= self.max_depth) \
                or num_samples < self.min_samples_split \
                or len(np.unique(y)) == 1:
            return {'leaf': True, 'value': np.mean(y)}

        # Exhaustively search every feature/threshold pair for the lowest weighted MSE.
        best_mse = float('inf')
        best_split = None
        for feat in range(num_features):
            thresholds = np.unique(X[:, feat])
            for threshold in thresholds:
                left_X, left_y, right_X, right_y = split_dataset(X, y, feat, threshold)
                if len(left_y) == 0 or len(right_y) == 0:
                    continue  # skip degenerate splits that leave one side empty
                current_mse = weighted_mse(left_y, right_y)
                if current_mse < best_mse:
                    best_mse = current_mse
                    best_split = {
                        'feature_index': feat,
                        'threshold': threshold,
                        'X_left': left_X,
                        'y_left': left_y,
                        'X_right': right_X,
                        'y_right': right_y
                    }

        # No valid split found: fall back to a leaf.
        if best_split is None:
            return {'leaf': True, 'value': np.mean(y)}

        # Recurse into both subsets.
        left_subtree = self._build_tree(best_split['X_left'], best_split['y_left'], depth + 1)
        right_subtree = self._build_tree(best_split['X_right'], best_split['y_right'], depth + 1)
        return {
            'leaf': False,
            'feature_index': best_split['feature_index'],
            'threshold': best_split['threshold'],
            'left': left_subtree,
            'right': right_subtree
        }

    def predict(self, X):
        return np.array([self._predict_sample(sample, self.tree) for sample in X])

    def _predict_sample(self, sample, tree):
        # Walk down the tree until a leaf is reached, then return its stored mean.
        if tree['leaf']:
            return tree['value']
        if sample[tree['feature_index']] <= tree['threshold']:
            return self._predict_sample(sample, tree['left'])
        return self._predict_sample(sample, tree['right'])
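
A quick way to sanity-check the class above is to fit it on a tiny synthetic dataset (the numbers below are made up purely for illustration):

X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y = np.array([1.0, 1.2, 0.9, 8.0, 8.5, 7.9])

model = DecisionTreeRegressor(max_depth=1)
model.fit(X, y)

# With depth 1 the tree makes a single split (around x = 3) and predicts the mean
# target of each side: roughly 1.03 on the left and 8.13 on the right.
print(model.predict(np.array([[2.5], [11.5]])))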

#machine learning #ml #machine_learning #programming #statistics #information gain #gini index #entropy #cart #regression