跳转至

Design

约 876 个字 319 行代码 预计阅读时间 7 分钟

牛顿迭代法

\[ x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)} \]

今天的目标就是做一个牛顿法的求解器。

类的抽象

如果我们想求 \(\sqrt{2}\),那代码很简单:

C++
#include <iostream>
#include <cmath>

int main()
{
    double a = 2;
    double tolerance = 1e-12;
    int max_iterations = 30;

    int k = 0;
    double x = 1.0;
    std::cout << "k = " << k << ", x = " << x << std::endl;

    while ((std::fabs(x * x - a) > tolerance) && (k++ < max_iterations))
    {
        x = x - (x * x - a) / (2 * x);
        std::cout << "k = " << k << ", x = " << x << std::endl;
    }
}

但是如果想把代码做的更通用一点,那么我们就要把代码做的更抽象一点。

C++
#include <iostream>
#include <cmath>

class NewtonSolver
{
private:
    double a;
    double tolerance;
    int max_iterations;

    int k;
    double x;

public:
    NewtonSolver(double a = 612.0, double tolerance = 1e-12, int max_iterations = 30) : a(a), tolerance(tolerance), max_iterations(max_iterations) {}

    void printInfo()
    {
        std::cout << "k = " << k << ", x = " << x << std::endl;
    }

    double f(double x)
    {
        return x * x - a;
    }

    double df(double x)
    {
        return 2 * x;
    }

    bool is_close(double x) 
    {
        return (std::fabs(f(x)) < tolerance);
    }

    void improve(double x0)
    {
        k = 0;
        x = x0;
        printInfo();

        while (!is_close(x) && (k < max_iterations))
        {
            k++;
            x = x - ((f(x)) / df(x));
            printInfo();
        }
    }
};

int main()
{
    NewtonSolver solver;
    solver.improve(1.0);
}

但是我们发现上边的代码有一点不好,我们的牛顿求解器需要满足更普遍的问题,但是上边的 \(f(x)\)\(f'(x)\) 是死的,不能被修改。

所以我们需要把 \(f(x)\)\(f'(x)\) 拿出来,单独定义成一个类,这个类继承自 NewtonSolver,然后我们就可以在这个类里定义 \(f(x)\)\(f'(x)\) 了。

C++
class NewtonSolver
{
private:
    virtual double f(double x) = 0;
    virtual double df(double x) = 0;

    double tolerance = 1e-12;
    int max_iterations = 30;
    int k;
    double x;

    void printInfo()
    {
        std::cout << "k = " << k << ", x = " << x << ", f(x) = " << f(x) << std::endl;
    }

    bool is_close(double x)
    {
        return (std::fabs(f(x)) < tolerance);
    }

public:
    NewtonSolver() : k(0), x(1) {}

    void improve(double x0)
    {
        k = 0;
        x = x0;
        printInfo();

        while (!is_close(x) && (k < max_iterations))
        {
            k++;
            x = x - ((f(x)) / df(x));
            printInfo();
        }
    }
};

class SqrtSolver : public NewtonSolver
{
private:
    double a;

public:
    SqrtSolver(double a) : a(a) {}

private:
    double f(double x) override
    {
        return x * x - a;
    }
    double df(double x) override
    {
        return 2 * x;
    }
};

这样我们只用提供 \(f(x)\)\(f'(x)\) 就可以很方便地使用牛顿迭代法了。

我们还可以自己写一个求解 \(n\) 次方的求解器:

C++
class NthRootSolver : public NewtonSolver
{
private:
    int n;
    double a;

public:
    NthRootSolver(int n, double a) : n(n), a(a) {}

    double f(double x) override
    {
        return std::pow(x, n) - a;
    }

    double df(double x) override
    {
        return n * pow(x, n - 1);
    }
};

另一种抽象方法是函数式编程。

函数式编程

函数式编程中,函数可以被当作参数传递,或者被当作返回值。

在函数式编程中没有类的概念,一切工作都是函数来完成的。

C++
#include <iostream>
#include <cmath>
#include <functional>

using fn = std::function<double(double)>;

void printInfo(int k, double x, double fx)
{
    std::cout << "k = " << k << ", x = " << x << ", f(x) = " << fx << std::endl;
}

bool is_close(double fx, double tolerance)
{
    return (std::fabs(fx) < tolerance);
}

double newton_solver(fn f, fn df, double x0, double tolerance = 1e-12, int max_iterations = 30)
{
    int k = 0;
    double x = x0;
    printInfo(k, x, f(x));

    while (!is_close(f(x), tolerance) && (k < max_iterations))
    {
        k++;
        x = x - (f(x) / df(x));
        printInfo(k, x, f(x));
    }
    return x;
}

double sqrt_newton(double a, double x0 = 1.0)
{
    auto f = [a](double x)
    { return x * x - a; };
    auto df = [](double x)
    { return x * 2; };

    return newton_solver(f, df, x0);
}

int main()
{
    sqrt_newton(64);
}

可以看到这里的全部工作都是通过函数来完成的,需要注意的是,我们使用了 fn 来进行函数的传递。using fn = std::function<double(double)>; 代表它是一个函数,需要一个 double 类型的参数,返回一个 double 类型的值。

Note

注意到在 sqrt_newton 函数中出现了一类新的语句 auto f = [a](double x) { return x * x - a; };, 这个叫做 lambda 语句,中括号代表需要从上下文获取参数,这里就是需要获取 a, 传递给大括号中的表达式。

所以这里用户在使用的时候只需要给定 a, f(x), df(x), 就能够快速求解了,与之前的类一样都很方便。

另一个例子是求解 \(n\) 次方根。

C++
1
2
3
4
5
6
7
8
9
double NthRootSolver(double a, int n, double x0 = 1.0)
{
    auto f = [a, n](double x)
    { return (pow(x, n) - a); };
    auto df = [n](double x)
    { return n * pow(x, n - 1); };

    return newton_solver(f, df, x0);
}

求解 \(\cos{x} - x^3\)

C++
1
2
3
4
5
6
7
8
9
double RectangleSolver(double x0=1.0)
{
    auto f = [](double x)
    { return (cos(x) - pow(x, 3)); };
    auto df = [](double x)
    { return (-sin(x) - 3 * pow(x, 2)); };

    return newton_solver(f, df, x0);
}

然后我们想对这个功能进行封装,跟前几节课讲的一样,函数声明放在 .h 头文件中, 函数实现放在 .cpp 文件中。

newton_solver.h

C++
1
2
3
4
5
6
7
8
#include <functional>

namespace ns{

    using fn = std::function<double(double)>;
    double newton_solver(fn f, fn df, double x0, double tolerance = 1e-12, int max_iterations = 30);

}

newton_solver.cpp

C++
#include <cmath>
#include <iostream>
#include "newton_solver.h"

namespace ns
{
    void printInfo(int k, double x, double fx)
    {
        std::cout << "k = " << k << ", x = " << x << ", f(x) = " << fx << std::endl;
    }

    bool is_close(double fx, double tolerance)
    {
        return (std::fabs(fx) < tolerance);
    }

    double newton_solver(fn f, fn df, double x0, double tolerance, int max_iterations)
    {
        int k = 0;
        double x = x0;
        printInfo(k, x, f(x));

        while (!is_close(f(x), tolerance) && (k < max_iterations))
        {
            k++;
            x = x - (f(x) / df(x));
            printInfo(k, x, f(x));
        }
        return x;
    }
}

main.cpp

C++
#include <cmath>
#include "newton_solver.h"

double sqrt_newton(double a, double x0 = 1.0)
{
    auto f = [a](double x)
    { return x * x - a; };
    auto df = [](double x)
    { return x * 2; };

    return ns::newton_solver(f, df, x0);
}

double NthRootSolver(double a, int n, double x0 = 1.0)
{
    auto f = [a, n](double x)
    { return (pow(x, n) - a); };
    auto df = [n](double x)
    { return n * pow(x, n - 1); };

    return ns::newton_solver(f, df, x0);
}

double RectangleSolver(double x0 = 1.0)
{
    auto f = [](double x)
    { return (cos(x) - pow(x, 3)); };
    auto df = [](double x)
    { return (-sin(x) - 3 * pow(x, 2)); };

    return ns::newton_solver(f, df, x0);
}

int main()
{
    sqrt_newton(64);

    NthRootSolver(64, 2);

    RectangleSolver();
}

截断

如果

C++
1
2
3
Base b;
Derived d;
b = d;

会发生什么呢?

我们可以写代码检验一下

C++
#include <iostream>

using namespace std;

class Base
{
protected:
    int data;

public:
    Base() : data(10) {}
    virtual void bar()
    {
        cout << "Base::bar(): data = " << data << endl;
    }
};

class Derived : public Base
{
private:
    int datad;

public:
    Derived() : datad(100) { data = 7; }
    void bar() override
    {
        cout << "Derived::bar()" << ", data = " << data << ", datad = " << datad << endl;
    }
};

int main()
{
    Base b;
    b.bar();

    Derived d;
    Base *p = &d;
    p->bar();

    b = d;
    b.bar();
    p = &b;
    p->bar();
}

前两次的 bar() 函数的输出是正常的,但是当把 d 赋值给 b 的时候,我们发现两次 bar() 函数都是调用的 Base::bar() 函数,而不是 Derived::bar() 函数,而且 Base 中的 data 的值被改变成了 7.

这意味着在进行赋值的时候,是把派生类的数据成员的值赋值给了基类的数据成员。而同时出于安全考虑,并不会修改基类的虚函数表指针,所以在调用 bar() 函数的时候,还是调用的基类的 bar() 函数。

如果想修改虚函数指针的话,可以这么做:

C++
1
2
3
void **pb = (void **)&b;
void **pd = (void **)&d;
*pb = *pd;

然后再次调用 b.bar() 我们发现静态绑定的 b.bar() 调用的是 Base::bar() 函数,而动态绑定的 p->bar() 调用的是 Derived::bar() 函数。

Text Only
1
2
3
4
Base::bar(): data = 10
Derived::bar(), data = 7, datad = 100
Base::bar(): data = 7
Derived::bar(), data = 7, datad = 1

但是出现问题了:datad 的值变成了 1, 而不是原来的 100。这是因为我们这时访问到的是 Base 对应偏移量的地址,而 BaseDerived 小,所以我们并不知道我们访问到的是什么数据,内存里边是什么,打印出来就是什么。

Relaxation Example

C++
class Expr {
public:
    virtual Expr* newExpr();
    virtual Expr& clone();
    virtual Expr self();
};

class BinaryExpr : public Expr {
public:
    virtual BinaryExpr* newExpr();  // OK
    virtual BinaryExpr& clone();    // OK
    virtual BinaryExpr self();      // Error!
};

子类中的 newExpr()clone() 是可以正常工作的,他俩与基类构成了覆写关系,因为返回的都是同一个类型。

self() 函数返回的是一个值,并不能构成覆写关系,所以会报错。

评论