Sequential Minimal Optimization

序列最小最优化(sequential minimal optimization, SMO)算法是支持向量机的学习算法,本文主要讲解了SMO算法,并且实现了一个简单基于SMO的SVM demo,最后简要了解封装SVM的scikit-learn库。

1. SVM简要回顾

SVM最终会得到一个对偶问题:

dual problem

解决这个优化问题可以得到a;进而,

SVM的学习,就是通过训练数据计算出a和b,然后通过决策函数判定xj的分类。其中a是一个向量,长度与训练数据的样本数相同,如果训练数据很大,那么这个向量会很长,不过绝大部分的分量值都是0,只有支持向量的对应的分量值大于0。而SMO算法就是用来解决SVM的对偶问题。

2. SMO

简介:SMO是一种启发式算法,其基本思想是:如果所有变量的解都满足了此最优化问题的KKT条件,那么这个最优化问题的解就得到了。否则,选择两个变量,固定其它变量,针对这两个变量构建一个二次规划问题,然后关于这个二次规划的问题的解就更接近原始的二次归还问题的解,因为这个解使得需要优化的问题的函数值更小。


2.1 KKT Conditions

SVM KKT Conditions

引入一个误差系数ε(tolerance),

2.2 SMO Algorithm

SMO Algorithm

两个关键点:变量如何选取以及如何更新

变量选取:

变量更新:

_More details in Platt’s paper:Sequential Minimal Optimization:
A Fast Algorithm for Training Support Vector Machines
_

3. SMO算法实现(python)

  • 只是简单版本,并未实现复杂的变量选取规则,简化如下:一次迭代中,遍历所有的ai,如果ai违反了KKT条件,那么就将它做为第一个变量,然后再遍历所有的ai,依次做为第二个变量,如果第二个变量有足够的下降,那么就更新两个变量。如果没有,就不更新。
  • 线性核
  • 原始数据:data
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    #!/usr/bin/python
    # -*- coding:utf8 -*-
    # Created by Helic on 2017/8/31
    # 参考 http://liuhongjiang.github.io/tech/blog/2012/12/28/svm-smo/

    import sys
    import math
    import matplotlib.pyplot as plt

    samples = []
    labels = []
    class svm_params:
    def __init__(self):
    self.a = []
    self.b = 0

    params = svm_params()
    e_dict = []

    train_data = "svm.train"

    def loaddata():
    fn = open(train_data, "r")
    for line in fn:
    line = line[:-1]
    vlist = line.split("\t")
    samples.append((int(vlist[0]), int(vlist[1])))
    labels.append(int(vlist[2]))
    params.a.append(0.0)
    fn.close()


    # 线性核
    def kernel(j, k):
    ret = 0.0
    for idx in range(len(samples[j])):
    ret += samples[j][idx] * samples[k][idx] # 计算内积
    return ret


    def predict_real_diff(i):
    """return Ei=g(xi)−yi=WXi+b-yi,返回KKT条件中的Ei变量"""
    diff = 0.0
    for j in range(len(samples)):
    diff += params.a[j] * labels[j] * kernel(j, i)
    diff = diff + params.b - labels[i]
    return diff

    def init_e_dict():
    for i in range(len(params.a)):
    e_dict.append(predict_real_diff(i))

    def update_e_dict():
    for i in range(len(params.a)):
    e_dict[i] = predict_real_diff(i)


    def train(tolerance, times, C):
    time = 0
    init_e_dict()
    updated = True
    while time < times and updated:
    updated = False
    time += 1
    for i in range(len(params.a)):
    ai = params.a[i]
    Ei = e_dict[i]
    # 违反KKT
    # agaist the KKT
    if (labels[i] * Ei < -tolerance and ai < C) or (labels[i] * Ei > tolerance and ai > 0):
    for j in range(len(params.a)):
    if j == i: continue
    eta = kernel(i, i) + kernel(j, j) - 2 * kernel(i, j)
    if eta <= 0:
    continue
    new_aj = params.a[j] + labels[j] * (e_dict[i] - e_dict[j]) / eta
    L = 0.0
    H = 0.0
    if labels[i] == labels[j]:
    L = max(0, params.a[j] + params.a[i] - C)
    H = min(C, params.a[j] + params.a[i])
    else:
    L = max(0, params.a[j] - params.a[i])
    H = min(C, C + params.a[j] - params.a[i])
    if new_aj > H:
    new_aj = H
    if new_aj < L:
    new_aj = L
    # 《统计学习方法》公式7.109(下同)
    # formula 7.109
    new_ai = params.a[i] + labels[i] * labels[j] * (params.a[j] - new_aj)

    # 第二个变量下降是否达到最小步长
    # decline enough for new_aj
    if abs(params.a[j] - new_aj) < 0.001:
    print "j = %d, is not moving enough" % j
    continue

    # formula 7.115
    new_b1 = params.b - e_dict[i] - labels[i]*kernel(i,i)*(new_ai-params.a[i]) - labels[j]*kernel(j,i)*(new_aj-params.a[j])
    # formula 7.116
    new_b2 = params.b - e_dict[j] - labels[i]*kernel(i,j)*(new_ai-params.a[i]) - labels[j]*kernel(j,j)*(new_aj-params.a[j])
    if new_ai > 0 and new_ai < C: new_b = new_b1
    elif new_aj > 0 and new_aj < C: new_b = new_b2
    else: new_b = (new_b1 + new_b2) / 2.0

    params.a[i] = new_ai
    params.a[j] = new_aj
    params.b = new_b
    update_e_dict()
    updated = True
    print "iterate: %d, changepair: i: %d, j:%d" %(time, i, j)

    def draw(tolerance, C):
    plt.xlabel(u"x1")
    plt.xlim(0, 100)
    plt.ylabel(u"x2")
    plt.ylim(0, 100)
    plt.title("SVM - %s, tolerance %f, C %f" % (train_data, tolerance, C))
    ftrain = open(train_data, "r")
    for line in ftrain:
    line = line[:-1]
    sam = line.split("\t")
    if int(sam[2]) > 0:
    plt.plot(sam[0],sam[1], 'or')
    else:
    plt.plot(sam[0],sam[1], 'og')

    w1 = 0.0
    w2 = 0.0
    for i in range(len(labels)):
    w1 += params.a[i] * labels[i] * samples[i][0]
    w2 += params.a[i] * labels[i] * samples[i][1]
    w = - w1 / w2

    b = - params.b / w2
    r = 1 / w2

    lp_x1 = [10, 90]
    lp_x2 = []
    lp_x2up = []
    lp_x2down = []
    for x1 in lp_x1:
    lp_x2.append(w * x1 + b)
    lp_x2up.append(w * x1 + b + r)
    lp_x2down.append(w * x1 + b - r)
    plt.plot(lp_x1, lp_x2, 'b')
    plt.plot(lp_x1, lp_x2up, 'b--')
    plt.plot(lp_x1, lp_x2down, 'b--')
    plt.show()

    if __name__ == "__main__":
    loaddata()
    print samples
    print labels
    # 惩罚系数
    # penalty for mis classify
    C = 10
    # 计算精度
    # computational accuracy
    tolerance = 0.0001
    train(tolerance, 100, C)
    print "a = ", params.a
    print "b = ", params.b
    support = []
    for i in range(len(params.a)):
    if params.a[i] > 0 and params.a[i] < C:
    support.append(samples[i])
    print "support vector = ", support
    draw(tolerance, C)

运行结果:

4. 封装SVM的scikit-learn库

在实际使用中,并不需要自己亲手书写SVM算法实现,scikit-learn SVM算法库封装了libsvm 和 liblinear 的实现,仅仅重写了算法了接口部分。


scikit-learn中SVM的算法库分为两类,一类是分类的算法库,包括SVC, NuSVC,和LinearSVC 3个类。另一类是回归算法库,包括SVR, NuSVR,和LinearSVR 3个类。相关的类都包裹在sklearn.svm模块之中。

对于SVC, NuSVC,和LinearSVC 3个分类的类,SVC和 NuSVC差不多,区别仅仅在于对损失的度量方式不同,而LinearSVC从名字就可以看出,他是线性分类,也就是不支持各种低维到高维的核函数,仅仅支持线性核函数,对线性不可分的数据不能使用。

同样的,对于SVR, NuSVR,和LinearSVR 3个回归的类, SVR和NuSVR差不多,区别也仅仅在于对损失的度量方式不同。LinearSVR是线性回归,只能使用线性核函数。

_More details in http://www.cnblogs.com/pinard/p/6117515.html_

5. Summary

SVM是一种相当高效的算法,在手写体数字识别取得不错的效果,适用范围很广,可以解决高维特征;更重要的一点是,kernel并非SVM所独有,kernel的本质是将低维映射到高维,在一些其他算法中也可以应用,如感知器算法。