{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c9cab631",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import datasets,tree\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score,precision_score, recall_score,classification_report\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f39df88d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import Counter\n",
    "\n",
    "class DecisionTree:\n",
    "    \"\"\"CART-style decision tree classifier that splits on Gini impurity.\n",
    "\n",
    "    Class labels are used directly as indices into per-class count arrays,\n",
    "    so y must contain non-negative integers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_depth=None, min_samples_split=2):\n",
    "        self.max_depth = max_depth  # maximum tree depth; None means no limit\n",
    "        self.min_samples_split = min_samples_split  # minimum samples a node needs to be split\n",
    "        self.tree = None  # root node, built by fit()\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"Build the tree from training data.\n",
    "\n",
    "        Args:\n",
    "            X: array-like of shape (n_samples, n_features), sample features.\n",
    "            y: array-like of shape (n_samples,), non-negative integer class labels.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if y is empty.\n",
    "        \"\"\"\n",
    "        X = np.asarray(X)\n",
    "        y = np.asarray(y)\n",
    "        if y.size == 0:\n",
    "            raise ValueError('cannot fit a decision tree on an empty dataset')\n",
    "        # Labels index the count arrays, so size them by the largest label\n",
    "        # rather than by the number of distinct labels that happen to be present.\n",
    "        self.n_classes_ = int(y.max()) + 1  # number of classes\n",
    "        self.n_features_ = X.shape[1]  # number of features\n",
    "        self.tree = self._grow_tree(X, y)  # recursively build the tree\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Return a list with the predicted class of every sample in X.\n",
    "\n",
    "        fit() must have been called first.\n",
    "        \"\"\"\n",
    "        return [self._predict(inputs) for inputs in X]\n",
    "\n",
    "    @staticmethod\n",
    "    def _gini(counts, total):\n",
    "        \"\"\"Return the Gini impurity 1 - sum(p_k ** 2) for per-class counts.\"\"\"\n",
    "        return 1.0 - sum((n / total) ** 2 for n in counts)\n",
    "\n",
    "    def _best_split(self, X, y):\n",
    "        \"\"\"Find the split that minimises the weighted Gini impurity.\n",
    "\n",
    "        Args:\n",
    "            X: numpy array, sample features.\n",
    "            y: numpy array, sample labels.\n",
    "\n",
    "        Returns:\n",
    "            (best_idx, best_thr): index of the best feature and its threshold,\n",
    "            or (None, None) when no split improves on the node's own impurity.\n",
    "        \"\"\"\n",
    "        m = y.size\n",
    "        if m <= 1:\n",
    "            return None, None\n",
    "\n",
    "        # Number of samples of each class in this node.\n",
    "        num_samples_per_class = np.bincount(y, minlength=self.n_classes_)\n",
    "        # Impurity of the unsplit node; a split must be strictly better.\n",
    "        best_gini = self._gini(num_samples_per_class, m)\n",
    "\n",
    "        best_idx, best_thr = None, None\n",
    "        for idx in range(self.n_features_):\n",
    "            thresholds, classes = zip(*sorted(zip(X[:, idx], y)))\n",
    "            num_left = [0] * self.n_classes_\n",
    "            num_right = num_samples_per_class.copy()\n",
    "\n",
    "            for i in range(1, m):  # traverse all possible partitions\n",
    "                # Move sample i - 1 from the right partition to the left one.\n",
    "                c = classes[i - 1]\n",
    "                num_left[c] += 1\n",
    "                num_right[c] -= 1\n",
    "\n",
    "                # Equal neighbouring values cannot be separated by a threshold.\n",
    "                if thresholds[i] == thresholds[i - 1]:\n",
    "                    continue\n",
    "\n",
    "                gini_left = self._gini(num_left, i)\n",
    "                gini_right = self._gini(num_right, m - i)\n",
    "                gini = (i * gini_left + (m - i) * gini_right) / m\n",
    "                if gini < best_gini:\n",
    "                    best_gini = gini\n",
    "                    best_idx = idx\n",
    "                    best_thr = (thresholds[i] + thresholds[i - 1]) / 2\n",
    "        return best_idx, best_thr\n",
    "\n",
    "    def _grow_tree(self, X, y, depth=0):\n",
    "        \"\"\"Recursively build the decision tree.\n",
    "\n",
    "        Args:\n",
    "            X: numpy array, sample features.\n",
    "            y: numpy array, sample labels.\n",
    "            depth: int, depth of the node being built.\n",
    "\n",
    "        Returns:\n",
    "            node: the Node at the root of the (sub)tree.\n",
    "        \"\"\"\n",
    "        num_samples_per_class = np.bincount(y, minlength=self.n_classes_)\n",
    "        # The node predicts the class that most of its samples belong to.\n",
    "        predicted_class = int(np.argmax(num_samples_per_class))\n",
    "        node = Node(\n",
    "            num_samples=y.size,\n",
    "            num_samples_per_class=num_samples_per_class,\n",
    "            predicted_class=predicted_class,\n",
    "        )\n",
    "\n",
    "        # max_depth=None means the depth is unlimited.\n",
    "        depth_ok = self.max_depth is None or depth < self.max_depth\n",
    "        if depth_ok and y.size >= self.min_samples_split:\n",
    "            idx, thr = self._best_split(X, y)\n",
    "            if idx is not None:\n",
    "                indices_left = X[:, idx] < thr\n",
    "                # Keep the node a leaf if rounding of the midpoint threshold\n",
    "                # would leave one side empty.\n",
    "                if indices_left.any() and not indices_left.all():\n",
    "                    X_left, y_left = X[indices_left], y[indices_left]\n",
    "                    X_right, y_right = X[~indices_left], y[~indices_left]\n",
    "                    node.feature_index = idx\n",
    "                    node.threshold = thr\n",
    "                    node.left = self._grow_tree(X_left, y_left, depth + 1)\n",
    "                    node.right = self._grow_tree(X_right, y_right, depth + 1)\n",
    "\n",
    "        return node\n",
    "\n",
    "    def _predict(self, inputs):\n",
    "        \"\"\"Walk from the root to a leaf and return that leaf's predicted class.\"\"\"\n",
    "        node = self.tree\n",
    "        # Internal nodes always have both children, so checking left is enough.\n",
    "        while node.left is not None:\n",
    "            if inputs[node.feature_index] < node.threshold:\n",
    "                node = node.left\n",
    "            else:\n",
    "                node = node.right\n",
    "        return node.predicted_class\n",
    "\n",
    "class Node:\n",
    "    \"\"\"A single node of the decision tree; a leaf while left/right are None.\"\"\"\n",
    "\n",
    "    def __init__(self, num_samples, num_samples_per_class, predicted_class):\n",
    "        # Statistics of the training samples that reached this node.\n",
    "        self.num_samples = num_samples  # number of samples in the node\n",
    "        self.num_samples_per_class = num_samples_per_class  # per-class sample counts\n",
    "        self.predicted_class = predicted_class  # class predicted by the node\n",
    "        # Split description: index of the split feature and its threshold.\n",
    "        self.feature_index = self.threshold = 0\n",
    "        # Left and right subtrees; both start out as None.\n",
    "        self.left = self.right = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e28761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset and split it\n",
    "iris_data = datasets.load_iris()\n",
    "x_train,x_test,y_train,y_test = train_test_split(iris_data.data,iris_data.target,test_size = 0.30,random_state = 20)\n",
    "# x_train    training-set features produced by the split (return value)\n",
    "# x_test    test-set features produced by the split (return value)\n",
    "# y_train    training-set labels produced by the split (return value)\n",
    "# y_test    test-set labels produced by the split (return value)\n",
    "# Features: sepal length, sepal width, petal length, petal width (sepal width, petal length)\n",
    "\n",
    "# Target attribute: three classes - setosa, versicolor and virginica iris\n",
    "label_list = ['mountainIris', 'variegatedIris', 'virginiaIris']\n",
    "\n",
    "# Train the decision tree classifier\n",
    "clf = DecisionTree(max_depth=3)\n",
    "clf.fit(x_train, y_train)\n",
    "\n",
    "# Predict on the test set\n",
    "y_pred = clf.predict(x_test) \n",
    "\n",
    "# Compute the accuracy\n",
    "print(\"Accuracy:\",accuracy_score(y_test, y_pred))\n",
    "\n",
    "# Print the classification report\n",
    "print(classification_report(y_test, y_pred, target_names=label_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54d7c7a7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
