自定义数据集 使用scikit-learn中svm的包实现svm分类

news/2025/2/3 23:15:07 标签: python, 开发语言

代码:

import numpy as np  # 导入用于数值计算的库
import matplotlib.pyplot as plt  # 导入用于绘图的库

# class1_points 和 class2_points 分别定义了两个类别的数据点,二维坐标
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

# 将 class1 和 class2 的 x 和 y 坐标合并为一个数据集
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]))  # 所有的 x1 坐标
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]))  # 所有的 x2 坐标

# 创建标签,class1 点标记为 1,class2 点标记为 -1
y = np.concatenate((np.ones(class1_points.shape[0]), -np.ones(class2_points.shape[0])))

# 初始化超平面的参数 w1, w2 和偏置 b
w1 = 0.1
w2 = 0.1
b = 0

# 学习率
learning_rate = 0.05

# 数据集的大小
l_data = x1_data.size

# 创建图形和子图
fig, (ax1, ax2) = plt.subplots(2, 1)

# 初始化存储每一步的损失值和迭代步数
step_list = np.array([])
loss_values = np.array([])

# 设定迭代次数,控制模型训练的周期
num_iterations = 1000
for n in range(1, num_iterations + 1):
    # 计算超平面预测值:z = w1 * x1 + w2 * x2 + b
    z = w1 * x1_data + w2 * x2_data + b

    # 计算每个点的损失:Hinge loss
    yz = y * z  # 预测值与真实标签的乘积
    loss = 1 - yz  # hinge loss 为 1 - yz,当 yz > 1 时,损失为 0
    loss[loss < 0] = 0  # 如果损失小于 0,置为 0
    hinge_loss = np.mean(loss)  # 计算平均损失(取所有数据点的损失均值)
    loss_values = np.append(loss_values, hinge_loss)  # 保存当前步的损失值
    step_list = np.append(step_list, n)  # 保存当前迭代步数

    # 初始化梯度
    gradient_w1 = 0
    gradient_w2 = 0
    gradient_b = 0

    # 梯度下降法计算梯度
    for i in range(len(y)):
        if loss[i] > 0:  # 仅考虑损失大于 0 的点
            gradient_w1 += -y[i] * x1_data[i]
            gradient_w2 += -y[i] * x2_data[i]
            gradient_b += -y[i]

    # 平均化梯度
    gradient_w1 /= len(y)
    gradient_w2 /= len(y)
    gradient_b /= len(y)

    # 更新超平面参数:w1, w2, b
    w1 -= learning_rate * gradient_w1
    w2 -= learning_rate * gradient_w2
    b -= learning_rate * gradient_b

    # 每 50 步或第一次迭代时,绘制一次更新图
    frequence_display = 50
    if n % frequence_display == 0 or n == 1:
        if np.abs(w2) < 1e-5:  # 避免 w2 太小导致无法计算
            continue

        # 计算超平面的直线方程,用于绘制超平面
        x1_min, x1_max = 0, 6  # x1 的范围
        x2_min, x2_max = -(w1 * x1_min + b) / w2, -(w1 * x1_max + b) / w2  # x2 的值,基于超平面方程计算

        # 清除上一轮绘制的图像,绘制新的图
        ax1.clear()
        ax1.scatter(x1_data[:len(class1_points)], x2_data[:len(class1_points)], c='red', label='Class 1')  # class1 红色
        ax1.scatter(x1_data[len(class1_points):], x2_data[len(class1_points):], c='blue', label='Class 2')  # class2 蓝色
        ax1.plot([x1_min, x1_max], [x2_min, x2_max], 'r-')  # 绘制超平面
        ax1.set_title(f"SVM: w1={round(w1.item(), 3)}, w2={round(w2.item(), 3)}, b={round(b.item(), 3)}")

        # 绘制损失函数的变化图
        ax2.clear()
        ax2.plot(step_list, loss_values, 'g-')  # 损失图,绿色线
        ax2.set_xlabel("Step")  # x 轴为步数
        ax2.set_ylabel("Loss")  # y 轴为损失值

        # 每次绘图后暂停 1 秒,展示图像
        plt.pause(1)

# 显示最终图形
plt.show()

结果:


http://www.niftyadmin.cn/n/5841107.html

相关文章

影视文件大数据高速分发方案

在当今的数字时代&#xff0c;影视行业的内容创作和传播方式经历了翻天覆地的变化。随着4K、8K高清视频的普及&#xff0c;以及虚拟现实(VR)和增强现实(AR)技术的发展&#xff0c;影视文件的数据量正以前所未有的速度增长。这就要求行业内的参与者必须拥有高效的大数据传输解决…

[Linux]如何將腳本(shell script)轉換到系統管理服務器(systemd service)來運行?

[InfluxDB]Monitor Tem. and Volt of RaspberryPi and Send Message by Line Notify 在Linux中&#xff0c;shell腳本(shell script)常用於運行各種自動化的流程&#xff0c;包含API串接&#xff0c;設置和啟動應用服務等等&#xff0c;腳本語法也相對易學易讀&#xff0c;因此…

Java泛型深度解析(JDK23)

第一章 泛型革命 1.1 类型安全的进化史 前泛型时代的类型转换隐患 代码的血泪史&#xff08;Java 1.4版示例&#xff09;&#xff1a; List rawList new ArrayList(); rawList.add("Java"); rawList.add(Integer.valueOf(42)); // 编译通过// 灾难在运行时爆发…

为AI聊天工具添加一个知识系统 之75 详细设计之16 正则表达式 之3 正则表达式模板

本文要点 概念图式schema&#xff1a;。处理“我” 立“每一个新提概念的提出都首先是语言的-含糊概念 Notion{ Yes&#xff0c;Unkown,No}&#xff0c;然后才是程序的-模糊符号Notation {True&#xff0c;False}&#xff0c;最后会是数据的-近似值 Approximation{Good,Fair,…

hive为什么建表,表存储什么

‌Hive建表的主要目的是为了方便管理和查询存储在Hadoop分布式文件系统&#xff08;HDFS&#xff09;上的大规模数据。‌ Hive作为一个构建在Hadoop之上的数据仓库工具&#xff0c;主要功能是提供类似SQL的查询语言HiveQL来处理和分析存储在HDFS中的数据。通过建表&#xff0c;…

JavaScript 中的 CSS 与页面响应式设计

JavaScript 中的 CSS 与页面响应式设计 JavaScript 中的 CSS 与页面响应式设计1. 引言2. JavaScript 与 CSS 的基本概念2.1 CSS 的作用2.2 JavaScript 的作用3. 动态控制样式:JavaScript 修改 CSS 的方法3.1 使用 `document.styleSheets` API3.2 使用 `classList` 修改类3.3 使…

Java创建对象有几种方式?

大家好&#xff0c;我是锋哥。今天分享关于【Java创建对象有几种方式?】面试题。希望对大家有帮助&#xff1b; Java创建对象有几种方式? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在 Java 中&#xff0c;创建对象有几种常见的方式&#xff0c;具体如下&…

JVM运行时数据区域-附面试题

Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域。这些区域 有各自的用途&#xff0c;以及创建和销毁的时间&#xff0c;有的区域随着虚拟机进程的启动而一直存在&#xff0c;有些区域则是 依赖用户线程的启动和结束而建立和销毁。 1. 程序计…