{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "f6fAdAPQIjip" }, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml" ] }, { "cell_type": "code", "source": [ "# Download the 'mnist_784' dataset (version 1) from OpenML.\n", "# `parser` is pinned to 'liac-arff', the default in effect for the recorded run,\n", "# because that default switches to 'auto' in scikit-learn 1.4 and the pandas\n", "# parser may return different data types. Pinning also silences the FutureWarning.\n", "mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='liac-arff')\n", "\n", "# Feature matrix and target labels taken from the fetched dataset.\n", "X = mnist.data\n", "y = mnist.target" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oAm24tpMIp73", "outputId": "c839d11d-e789-401b-eae3-2e063aba2ca5" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/sklearn/datasets/_openml.py:968: FutureWarning: The default value of `parser` will change from `'liac-arff'` to `'auto'` in 1.4. You can set `parser='auto'` to silence this warning. Therefore, an `ImportError` will be raised from 1.4 if the dataset is dense and pandas is not installed. Note that the pandas parser may return different data types. 
See the Notes Section in fetch_openml's API doc for details.\n", " warn(\n" ] } ] }, { "cell_type": "code", "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# Two-stage split: 50,000 training samples first, then the remainder is divided\n", "# into a 10,000-sample test set and a validation set holding whatever is left.\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, train_size=50000, random_state=42\n", ")\n", "X_test, X_val, y_test, y_val = train_test_split(\n", "    X_test, y_test, train_size=10000, random_state=42\n", ")" ], "metadata": { "id": "a3H37zSsIr9i" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "from sklearn.metrics import accuracy_score\n", "from sklearn.ensemble import RandomForestClassifier\n", "\n", "# Random forest baseline, trained on the training split.\n", "rnd_clf = RandomForestClassifier(\n", "    n_estimators=500, max_leaf_nodes=32, random_state=42\n", ").fit(X_train, y_train)\n", "\n", "# The validation accuracy is the cell's displayed value.\n", "y_val_pred = rnd_clf.predict(X_val)\n", "accuracy_score(y_true=y_val, y_pred=y_val_pred)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FGLfzRfaIxEp", "outputId": "f7d676f6-4a9a-46c0-de00-a6fd55be3e78" }, "execution_count": 4, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.8694" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "code", "source": [ "from sklearn.ensemble import ExtraTreesClassifier\n", "\n", "# Extra-trees model with the same hyperparameters as the random forest.\n", "ext_clf = ExtraTreesClassifier(\n", "    n_estimators=500, max_leaf_nodes=32, random_state=42\n", ").fit(X_train, y_train)\n", "\n", "y_val_pred = ext_clf.predict(X_val)\n", "accuracy_score(y_true=y_val, y_pred=y_val_pred)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Az6baZyaI1gf", "outputId": "71199c7e-12bb-4e3e-c4ed-ea6a3818f09b" }, "execution_count": 5, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.8596" ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "code", "source": [ "from sklearn.svm import LinearSVC\n", "\n", "# NOTE(review): tol=20 and max_iter=100 are the values used for the recorded\n", "# score; presumably chosen to keep training short - confirm before changing.\n", "svm_clf = LinearSVC(max_iter=100, tol=20, random_state=42).fit(X_train, y_train)\n", "\n", "y_val_pred = svm_clf.predict(X_val)\n", "accuracy_score(y_true=y_val, y_pred=y_val_pred)" ], "metadata": { "colab": { 
"base_uri": "https://localhost:8080/" }, "id": "gsduLLGVI4hm", "outputId": "b2799987-d3f3-4dc7-80c5-ebef8120d4ad" }, "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.8777" ] }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "code", "source": [ "from sklearn.ensemble import VotingClassifier\n", "\n", "# Combine the three classifiers built in the previous cells into one ensemble.\n", "# No `voting` argument is passed, so the library's default voting mode applies.\n", "voting_clf = VotingClassifier(\n", "    estimators=[\n", "        ('rf', rnd_clf),\n", "        ('et', ext_clf),\n", "        ('svc', svm_clf),\n", "    ],\n", ")\n", "\n", "# fit() is the last expression, so the fitted ensemble's repr is displayed.\n", "voting_clf.fit(X_train, y_train)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 109 }, "id": "mrP1ENWOI6_x", "outputId": "9d898938-46d2-4d92-f263-1bc841991b24" }, "execution_count": 7, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "VotingClassifier(estimators=[('rf',\n", " RandomForestClassifier(max_leaf_nodes=32,\n", " n_estimators=500,\n", " random_state=42)),\n", " ('et',\n", " ExtraTreesClassifier(max_leaf_nodes=32,\n", " n_estimators=500,\n", " random_state=42)),\n", " ('svc',\n", " LinearSVC(max_iter=100, random_state=42,\n", " tol=20))])" ], "text/html": [ "
VotingClassifier(estimators=[('rf',\n", " RandomForestClassifier(max_leaf_nodes=32,\n", " n_estimators=500,\n", " random_state=42)),\n", " ('et',\n", " ExtraTreesClassifier(max_leaf_nodes=32,\n", " n_estimators=500,\n", " random_state=42)),\n", " ('svc',\n", " LinearSVC(max_iter=100, random_state=42,\n", " tol=20))])In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
VotingClassifier(estimators=[('rf',\n", " RandomForestClassifier(max_leaf_nodes=32,\n", " n_estimators=500,\n", " random_state=42)),\n", " ('et',\n", " ExtraTreesClassifier(max_leaf_nodes=32,\n", " n_estimators=500,\n", " random_state=42)),\n", " ('svc',\n", " LinearSVC(max_iter=100, random_state=42,\n", " tol=20))])
RandomForestClassifier(max_leaf_nodes=32, n_estimators=500, random_state=42)
ExtraTreesClassifier(max_leaf_nodes=32, n_estimators=500, random_state=42)
LinearSVC(max_iter=100, random_state=42, tol=20)