Anand

ML Intro: Iris Dataset

Hello World of Machine Learning: Iris Dataset 🐍🤖

In machine learning, the Iris dataset is often considered a "hello world" example. It's a classic dataset that is widely used for demonstration and testing purposes. In this tutorial, we'll explore the Iris dataset, load it from scikit-learn (sklearn), visualize the data, train a machine learning model, and evaluate its performance.


Prerequisites

Before proceeding, make sure you have the following libraries installed:

  • NumPy: NumPy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a collection of mathematical functions to operate on these arrays.

  • Pandas: Pandas is a data manipulation and analysis library for Python, providing easy-to-use data structures and functions to work with structured data.

  • Matplotlib: Matplotlib is a plotting library for Python, which produces publication-quality figures in a variety of formats and interactive environments across platforms.

  • Scikit-learn: Scikit-learn is a machine learning library for Python, offering various tools for data mining and data analysis, built on NumPy, SciPy, and matplotlib.

You can install these libraries using pip:

pip install numpy pandas matplotlib scikit-learn

Importing the Iris Dataset

The Iris dataset is included in the scikit-learn library, so we can import it directly without downloading any external files. Let's start by importing the necessary libraries and loading the dataset:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd
import matplotlib.pyplot as plt
# Load the Iris dataset
iris = load_iris()

The load_iris() function from the sklearn.datasets module loads the Iris dataset. This dataset contains information about three different species of Iris flowers: Setosa, Versicolor, and Virginica. Each data instance represents a single flower, and the features include sepal length, sepal width, petal length, and petal width (all measured in centimeters).
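
As a quick optional check (not part of the original snippet above), you can confirm these numbers directly from the loaded iris object: 150 flowers, 4 measurements each, and 50 samples per species.

import numpy as np

# 150 flowers, each described by 4 measurements
print(iris.data.shape)            # (150, 4)

# The target holds one label (0, 1, or 2) per flower; each species appears 50 times
print(np.bincount(iris.target))   # [50 50 50]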

Understanding the Dataset

Let's explore the dataset and its components:
# Print the description of the dataset
print(iris.DESCR)
# Print the feature names
print("Feature names:", iris.feature_names)

output:

Feature names: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# Print the target names (species)
print("Target names:", iris.target_names)

output:

Target names: ['setosa' 'versicolor' 'virginica']
# Print the first few rows of the data
iris_data = pd.DataFrame(iris.data, columns=iris.feature_names)
print(iris_data.head())

output:

|   | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) |
|---|--------------------|-------------------|--------------------|-------------------|
| 0 | 5.1                | 3.5               | 1.4                | 0.2               |
| 1 | 4.9                | 3.0               | 1.4                | 0.2               |
| 2 | 4.7                | 3.2               | 1.3                | 0.2               |
| 3 | 4.6                | 3.1               | 1.5                | 0.2               |
| 4 | 5.0                | 3.6               | 1.4                | 0.2               |


The dataset description provides information about the Iris dataset, including its origin and characteristics. The feature_names attribute lists the names of the input features, and the target_names attribute lists the names of the target classes (species). The dataset itself is stored in the data attribute, and we've converted it into a Pandas DataFrame for easier viewing.
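
If you also want the species names visible next to the measurements, one convenient option (an extra step, not part of the original snippet) is to map the numeric labels onto a new column:

# Map each numeric label (0, 1, 2) to its species name and store it in the DataFrame
iris_data['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(iris_data.head())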

# Create a scatter plot of petal length vs petal width
plt.scatter(iris_data['petal length (cm)'], iris_data['petal width (cm)'], c=iris.target, cmap='viridis')
plt.xlabel('Petal Length (cm)')
plt.ylabel('Petal Width (cm)')
plt.title('Iris Dataset: Petal Length vs Petal Width')
plt.colorbar(label='Species')
plt.show()

plot:

[Scatter plot of petal length vs petal width, colored by species]

This code will generate a scatter plot where each data point is colored according to its species. The plot should reveal a clear separation between the Setosa species and the other two species (Versicolor and Virginica) based on petal length and petal width.
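
To back up that visual impression with numbers, a short summary (an optional addition, not in the original post) of the average petal measurements per species shows the same gap:

# Mean petal measurements per species; setosa is clearly smaller than the other two
species = pd.Series(iris.target_names[iris.target])
print(iris_data[['petal length (cm)', 'petal width (cm)']].groupby(species).mean())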

Training a Machine Learning Model

Now, let's train a machine learning model to classify the Iris species based on the input features. We'll use the K-Nearest Neighbors (KNN) algorithm from scikit-learn:

from sklearn.neighbors import KNeighborsClassifier

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)

# Create and train the KNN model
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)

Here, we've split the dataset into training and testing sets using the train_test_split function from scikit-learn. We've set aside 20% of the data for testing (test_size=0.2).
We then create an instance of the KNeighborsClassifier class with n_neighbors=5, which means the model will consider the 5 nearest neighbors when making predictions.
Finally, we train the model using the fit method, passing in the training data (X_train) and training labels (y_train).
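
Once the model is fitted, you can also try it on a single made-up flower. The measurements below are purely illustrative (they are not taken from the dataset), but since the petals are small, the prediction should come back as setosa:

# Predict the species for one hypothetical flower:
# [sepal length, sepal width, petal length, petal width] in cm
sample = [[5.0, 3.4, 1.5, 0.2]]
prediction = knn.predict(sample)
print(iris.target_names[prediction[0]])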

Model Evaluation

After training the model, we can evaluate its performance on the testing set:

from sklearn.metrics import accuracy_score, classification_report

# Make predictions on the testing set
y_pred = knn.predict(X_test)

# Calculate the accuracy score
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# Print the classification report
print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

The accuracy_score function from scikit-learn calculates the accuracy of the model's predictions by comparing the predicted labels (y_pred) with the true labels (y_test).
The classification_report function provides a more detailed evaluation, including precision, recall, and F1-score for each class.

Accuracy: 0.9666666666666667

Classification Report:

|              | precision | recall | f1-score | support |
|--------------|-----------|--------|----------|---------|
| setosa       | 1.00      | 1.00   | 1.00     | 11      |
| versicolor   | 0.90      | 0.94   | 0.92     | 9       |
| virginica    | 0.97      | 0.95   | 0.96     | 10      |
| macro avg    | 0.96      | 0.96   | 0.96     | 30      |
| weighted avg | 0.96      | 0.97   | 0.96     | 30      |

In this example, the KNN model achieved an accuracy of 96.67% on the testing set, which is quite good for the Iris dataset. The classification report shows that the model performed well across all three classes, with high precision, recall, and F1-scores.
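
If you want to see which test flowers were misclassified rather than just the aggregate scores, a confusion matrix (an extra step beyond the walkthrough above) breaks the predictions down class by class:

from sklearn.metrics import confusion_matrix

# Rows correspond to the true species, columns to the predicted species;
# off-diagonal entries count the misclassified test flowers
cm = confusion_matrix(y_test, y_pred)
print(iris.target_names)
print(cm)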

Conclusion

In this tutorial, we explored the classic Iris dataset, loaded it from scikit-learn, visualized the data, trained a KNN machine learning model, and evaluated its performance. This exercise serves as a great introduction to the world of machine learning, covering essential steps like data exploration, model training, and model evaluation.

Feel free to experiment further with this dataset: try different machine learning algorithms or explore additional evaluation metrics (one possible starting point is sketched below). Happy coding!
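
As one example of that kind of experiment, here is a quick sketch that swaps in a different classifier (a decision tree, chosen arbitrarily) and uses 5-fold cross-validation instead of a single train/test split:

from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# 5-fold cross-validation: fit on 4/5 of the data, score on the remaining 1/5, five times over
tree = DecisionTreeClassifier(random_state=42)
scores = cross_val_score(tree, iris.data, iris.target, cv=5)
print("Cross-validation scores:", scores)
print("Mean accuracy:", scores.mean())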

About Me:
🖇️LinkedIn
🧑‍💻GitHub
