Experiment of Different Basis Functions

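[Figure: Comparison of Different Basis Functions for Function Approximation. Top row: 3D surfaces; bottom row: 2D heatmaps. Columns: Original, Polynomial p=4, Fourier p=4, Fourier Q q=4.]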

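  • We generate a 100x100 grid of sample points, each with coordinates $(x, y) \in [0, 1]^2$, and evaluate the target $z = \sin(2\pi x)\cos(2\pi y)$ on it. The resulting data is shown in the leftmost panel of the figure above, Original (3D).
  • We fit this function with three different bases: poly_basis, fourier_basis, and fourierq_basis.
    • poly_basis: maximum degree p=4, for (p+1)(p+2)/2 = 15 terms in total.
    • fourier_basis: maximum degree p=4, also 15 terms in total.
    • fourierq_basis: the Q only marks a different rule for how many terms are kept; with q=4 there are (q+1)^2 = 25 terms in total.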
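  • The fitted results (3D surface and 2D heatmap) are shown in the figure above, where
    • the first column is the original function
    • the second column is the poly_basis fit
    • the third column is the fourier_basis fit
    • the fourth column is the fourierq_basis fit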
  • The fits use NumPy's least-squares solver (np.linalg.lstsq). The resulting MSEs are:
    • Polynomial basis MSE: 0.114360
    • Fourier basis MSE: 0.102145
    • Fourier Q basis MSE: 0.031606
  • Ranked by fit quality (lower MSE is better): fourierq_basis > fourier_basis > poly_basis
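The number of terms in each basis follows directly from how its exponents or frequencies are enumerated. A minimal sketch of that count, using the hypothetical helpers `n_total_degree_terms` and `n_tensor_terms` (they are not part of feature.py):

```python
def n_total_degree_terms(p: int) -> int:
    """Count pairs (a, b) of non-negative integers with a + b <= p."""
    return sum(1 for a in range(p + 1) for b in range(p + 1 - a))


def n_tensor_terms(q: int) -> int:
    """Count pairs (c1, c2) with c1 and c2 each in {0, ..., q}."""
    return (q + 1) ** 2


# poly_basis and fourier_basis keep terms up to a total degree p,
# fourierq_basis keeps every combination of per-axis frequencies up to q.
print(n_total_degree_terms(4))  # 15
print(n_tensor_terms(4))        # 25
```

The extra 10 terms of fourierq_basis are the higher mixed frequencies such as cos(π(4x + 4y)) that the total-degree rule drops.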

Problem

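Continuing from the earlier value-function (VF) state-value notes, we want a feature function that describes the state s well, so that its output vector can be linearly combined into the final state value:

$$\hat v(s, \mathbf{w}) = \mathbf{w}^\top \mathbf{x}(s) = \sum_i w_i x_i(s)$$

The main goal here is to study the form of the feature function $\mathbf{x}(s)$.

This part mainly follows Sections 9.5.1 and 9.5.2 of Sutton's book.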
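A minimal sketch of the linear form: the value estimate is the dot product of a weight vector with the feature vector of the state. The toy feature map, the function names, and the semi-gradient TD(0) step below are illustrative assumptions, not code from the original notes.

```python
import numpy as np


def features(s):
    """Hypothetical feature map for a 2-D state s = (s1, s2): [1, s1, s2]."""
    s1, s2 = s
    return np.array([1.0, s1, s2])


def v_hat(s, w):
    """Linear state-value estimate: dot product of weights and features."""
    return float(w @ features(s))


def td0_update(w, s, r, s_next, alpha=0.1, gamma=0.9):
    """One semi-gradient TD(0) step; the gradient of v_hat w.r.t. w is features(s)."""
    td_error = r + gamma * v_hat(s_next, w) - v_hat(s, w)
    return w + alpha * td_error * features(s)


w = np.zeros(3)
w = td0_update(w, s=(0.0, 0.5), r=1.0, s_next=(0.5, 0.5))
print(w)  # the TD error is 1, so w moves by 0.1 * features((0.0, 0.5))
```

Because the estimate is linear in the weights, swapping in a richer feature map changes nothing else in the update, which is why the choice of basis is studied on its own here.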

Polynomial basis

Polynomial Basis Definition (from Sutton's book)

Suppose we have state $s = (s_1, s_2, \dots, s_k)^\top$, with each $s_i \in \mathbb{R}$. For a k-dimensional state space, each order-n polynomial-basis feature $x_i$ can be written as

$$x_i(s) = \prod_{j=1}^{k} s_j^{c_{i,j}}$$

where each $c_{i,j}$ is an integer in the set $\{0, 1, \dots, n\}$ for an integer $n \ge 0$. These features make up the order-n polynomial basis for dimension k, which contains $(n+1)^k$ different features.

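For example, with $k = 2$ and $n = 1$ we can consider the polynomial basis:

$$\mathbf{x}(s) = (1, s_1, s_2, s_1 s_2)^\top$$

And then we can fit the function using the polynomial basis:

$$\hat v(s, \mathbf{w}) = \mathbf{w}^\top \mathbf{x}(s) = w_1 + w_2 s_1 + w_3 s_2 + w_4 s_1 s_2$$

where $\mathbf{w} \in \mathbb{R}^4$.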
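The definition can be sketched for any dimension k by enumerating every exponent vector with `itertools.product`. The helper name `poly_basis` is hypothetical, and the feature order is simply the enumeration order of the exponent tuples.

```python
from itertools import product

import numpy as np


def poly_basis(s, n):
    """Order-n polynomial basis of a k-dim state: prod_j s_j**c_j, c_j in {0..n}."""
    s = np.asarray(s, dtype=float)
    exponents = list(product(range(n + 1), repeat=s.size))
    return np.array([np.prod(s ** np.array(c)) for c in exponents])


# k = 2, n = 1: exponent tuples (0,0), (0,1), (1,0), (1,1)
x = poly_basis([2.0, 3.0], n=1)
print(x)  # features 1, s2, s1, s1*s2 -> 1, 3, 2, 6

# k = 3, n = 2 gives (n + 1)**k = 27 features
print(poly_basis([0.5, 0.5, 0.5], n=2).shape)
```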

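Because the degree of the later terms keeps growing and the fit becomes very unstable without a limit, we still cap the total degree of each feature at p: $\sum_{j=1}^{k} c_{i,j} \le p$.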
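One way to see the instability is through the condition number of the least-squares design matrix, which grows quickly with the degree. A sketch under stated assumptions: a 20x20 grid on $[0, 1]^2$ and a hypothetical helper `total_degree_poly` that keeps only terms with total degree at most p.

```python
import numpy as np


def total_degree_poly(points, p):
    """Columns x**(i-j) * y**j for every total degree i <= p."""
    x, y = points[:, 0], points[:, 1]
    cols = [x ** (i - j) * y ** j for i in range(p + 1) for j in range(i + 1)]
    return np.column_stack(cols)


g = np.linspace(0, 1, 20)
gx, gy = np.meshgrid(g, g)
points = np.column_stack((gx.ravel(), gy.ravel()))

# Condition number of the design matrix for increasing maximum degree
conds = {p: np.linalg.cond(total_degree_poly(points, p)) for p in (2, 4, 8)}
for p, c in conds.items():
    n_terms = (p + 1) * (p + 2) // 2
    print(f"p={p}: {n_terms} terms, condition number {c:.2e}")
```

A larger condition number means small changes in the targets can move the fitted weights a lot, which is one reason to keep p small.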

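In our earlier task, the better result in grid_world/result.md used $p = 3$, i.e.

$$\mathbf{x}(s) = [1, s_1, s_2, s_1^2, s_1 s_2, s_2^2, s_1^3, s_1^2 s_2, s_1 s_2^2, s_2^3]^\top$$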

The results are as follows:

[Figures: True State Values 3D; TD-Linear(poly-3) Estimated State Values 3D]

Fourier basis

Fourier Basis Definition (from Sutton's book)

Suppose each state s corresponds to a vector of k numbers, $s = (s_1, s_2, \dots, s_k)^\top$, with each $s_i \in [0, 1]$. The $i$th feature in the order-$n$ Fourier cosine basis can then be written

$$x_i(s) = \cos(\pi s^\top c^i)$$

where $c^i = (c^i_1, \dots, c^i_k)^\top$, with $c^i_j \in \{0, \dots, n\}$ for $j = 1, \dots, k$ and $i = 1, \dots, (n+1)^k$. This defines a feature for each of the $(n+1)^k$ possible integer vectors $c^i$. The inner product $s^\top c^i$ has the effect of assigning an integer in $\{0, \dots, n\}$ to each dimension of $s$. This integer determines the feature's frequency along that dimension. The features can of course be shifted and scaled to suit the bounded state space of a particular application.

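For example, in our case $k = 2$, so $s = (s_1, s_2)^\top$; for order $n$ there are $(n+1)^2$ possible values of $c^i = (c^i_1, c^i_2)^\top$, each giving one feature:

$$x_i(s) = \cos\big(\pi (c^i_1 s_1 + c^i_2 s_2)\big), \quad c^i_1, c^i_2 \in \{0, \dots, n\}$$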
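The Fourier cosine basis can be sketched the same way for any k by enumerating the integer vectors $c^i$. The helper name `fourier_cosine_basis` is hypothetical, and the state is assumed to be already scaled to $[0, 1]^k$.

```python
from itertools import product

import numpy as np


def fourier_cosine_basis(s, n):
    """Order-n Fourier cosine basis: cos(pi * s^T c) for every c in {0..n}^k."""
    s = np.asarray(s, dtype=float)
    # One row per integer vector c; shape ((n + 1)**k, k)
    c = np.array(list(product(range(n + 1), repeat=s.size)))
    return np.cos(np.pi * (c @ s))


# k = 2, n = 1: c runs over (0,0), (0,1), (1,0), (1,1)
x = fourier_cosine_basis([0.0, 1.0], n=1)
print(x)  # cos(0), cos(pi), cos(0), cos(pi) -> 1, -1, 1, -1

# k = 2, n = 4 gives (n + 1)**k = 25 features, matching fourierq_basis with q=4
print(fourier_cosine_basis([0.3, 0.7], n=4).shape)
```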

Remember to constrain the state here so that each $s_i \in [0, 1]$.
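For a grid world, one way to meet this constraint is to rescale the cell indices before computing features. A minimal sketch, assuming 0-based row and column indices and therefore dividing by the count minus one so the largest coordinate is exactly 1 (with 1-based coordinates, dividing by the row and column counts has the same effect). The helper `normalize_grid_states` is hypothetical.

```python
import numpy as np


def normalize_grid_states(n_rows, n_cols):
    """Return an [n_rows * n_cols, 2] array of (x, y) in [0, 1] for 0-based cells."""
    rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
    x = cols / max(n_cols - 1, 1)  # column index -> horizontal coordinate
    y = rows / max(n_rows - 1, 1)  # row index -> vertical coordinate
    return np.column_stack((x.ravel(), y.ravel()))


states = normalize_grid_states(5, 5)
print(states.shape)                # (25, 2)
print(states.min(), states.max())  # 0.0 1.0
```

The resulting array can be passed directly to a feature function that expects an [N, 2] array of positions.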

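In our earlier task, the better result in grid_world/result.md also used $p = 3$, i.e.

$$\mathbf{x}(s) = \big[\cos\big(\pi (c_1 s_1 + c_2 s_2)\big)\big]_{c_1, c_2 \ge 0,\; c_1 + c_2 \le 3}$$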

The results are as follows:

[Figures: True State Values 3D; TD-Linear(fourier-3) Estimated State Values 3D]

Appendix

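The code below reproduces the results in Experiment of Different Basis Functions. Note that the data generated here already satisfies $x, y \in [0, 1]$. For grid data, simply divide the row and column coordinates by the number of rows and columns so that the maximum value is 1, and the features can then learn something useful.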

python feature.py
# File: feature.py
import numpy as np
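try:
    from matplotlib.colors import Normalize
except ImportError:
    # matplotlib is optional; only the plotting branch of _test_fit_curve needs it
    Normalize = None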
 
def pos2poly(state, p=2):
    r"""
    state: (x,y) or [(x1,y1), (x2,y2), ...]
    p: the maximum degree of the polynomial
    
    return: poly_basis: [N, (p+1)(p+2)/2]
    - [1, x, y, x^2, xy, y^2, ...]
    """
    if not isinstance(state, np.ndarray):
        state = np.array(state)
    if state.ndim == 1:
        state = state.reshape(1, -1)
    assert state.shape[-1] == 2
    
    n_samples = state.shape[0]
    x = state[:, 0]
    y = state[:, 1]
    
    features = [np.ones(n_samples)]
    
    for i in range(1, p + 1):
        for j in range(i + 1):
            # i \in [1, p] j \in [0, i]
            # {x^(i-j) * y^j}
            features.append((x ** (i - j)) * (y ** j))
            
    return np.column_stack(features)
 
def pos2fourier(state, p=2):
    r"""
    state: (x,y) or [(x1,y1), (x2,y2), ...]
    p: the maximum degree of the fourier series
    
    return: fourier_basis: [N, (p+1)(p+2)/2]
    - [1, cos(\pi x), cos(\pi y), cos(2\pi x), cos(\pi (x+y)), cos(2\pi y), ...]
    """
    if not isinstance(state, np.ndarray):
        state = np.array(state)
    if state.ndim == 1:
        state = state.reshape(1, -1)
    assert state.shape[-1] == 2
    
    n_samples = state.shape[0]
    x = state[:, 0]
    y = state[:, 1]
    
    features = [np.ones(n_samples)]
    
    for i in range(1, p + 1):
        for j in range(i + 1):
            # i \in [1, p] j \in [0, i]
            # {cos(\pi (x*(i-j) + y*j))}
            features.append(np.cos(np.pi * (x * (i - j) + y * j)))
            
    return np.column_stack(features)
 
def pos2fourierq(state, q=2):
    r"""
    state: (x,y) or [(x1,y1), (x2,y2), ...]
    q: cos(\pi(c_1 x + c_2 y)), and c_1, c_2 \in {0,...,q}
    
    return: fourierq_basis: [N, (q+1)^2]
    - [1, cos(\pi y), ..., cos(q\pi y), cos(\pi x), cos(\pi (x+y)), ...] (c1 is the outer loop, c2 the inner loop)
    """
    if not isinstance(state, np.ndarray):
        state = np.array(state)
    if state.ndim == 1:
        state = state.reshape(1, -1)
    assert state.shape[-1] == 2
    
    x = state[:, 0]
    y = state[:, 1]
    
    features = []
    
    for c1 in range(0, q + 1):
        for c2 in range(0, q + 1):
            features.append(np.cos(np.pi * (c1 * x + c2 * y)))
            
    return np.column_stack(features)
 
def _test():
    # [1, x, y]
    print(pos2poly((1, 2), p=1))
 
    # [1, x, y, x^2, xy, y^2]
    print(pos2poly((1, 2), p=2))
    print(pos2poly([(1, 2), (3, 4)], p=2))
 
    # [1, cos(\pi x), cos(\pi y)]
    print(pos2fourier((1, 2), p=1))
 
    # [1, cos(\pi x), cos(\pi y), cos(2\pi x), cos(\pi (x+y)), cos(2\pi y)]
    print(pos2fourier((1, 2), p=2))
    print(pos2fourier([(1, 2), (3, 4)], p=2))
 
    # [1, cos(\pi y), cos(\pi x), cos(\pi (x+y))]
    print(pos2fourierq((1, 2), q=1))
 
def _test_fit_curve():
    """
    Test function to demonstrate fitting a curve using different basis functions
    """
    # Generate some sample data
    n_samples = 100
    x = np.linspace(0, 1, n_samples)
    y = np.linspace(0, 1, n_samples)
    grid_x, grid_y = np.meshgrid(x, y)
    points = np.column_stack((grid_x.flatten(), grid_y.flatten()))
    
    # Create a target function: z = sin(2πx) * cos(2πy)
    z = np.sin(2 * np.pi * grid_x) * np.cos(2 * np.pi * grid_y)
    z_flat = z.flatten()
    
    # Fit using polynomial basis
    poly_features = pos2poly(points, p=4)
    poly_weights = np.linalg.lstsq(poly_features, z_flat, rcond=None)[0]
    poly_pred = poly_features @ poly_weights
    poly_error = np.mean((poly_pred - z_flat) ** 2)
    
    # Fit using Fourier basis
    fourier_features = pos2fourier(points, p=4)
    fourier_weights = np.linalg.lstsq(fourier_features, z_flat, rcond=None)[0]
    fourier_pred = fourier_features @ fourier_weights
    fourier_error = np.mean((fourier_pred - z_flat) ** 2)
    
    # Fit using Fourier Q basis
    fourierq_features = pos2fourierq(points, q=4)
    fourierq_weights = np.linalg.lstsq(fourierq_features, z_flat, rcond=None)[0]
    fourierq_pred = fourierq_features @ fourierq_weights
    fourierq_error = np.mean((fourierq_pred - z_flat) ** 2)
    
    print(f"Polynomial basis MSE: {poly_error:.6f}")
    print(f"Fourier basis MSE: {fourier_error:.6f}")
    print(f"Fourier Q basis MSE: {fourierq_error:.6f}")
    
    # Plot the original function and the three fits if matplotlib is available
    try:
        import matplotlib.pyplot as plt
        title = 'Comparison of Different Basis Functions for Function Approximation'
        fig = plt.figure(title, figsize=(15, 10))
        fig.suptitle(title, fontsize=16)
        
        # Common colormap and normalization for consistent colors
        cmap = 'viridis'
        norm = Normalize(vmin=z.min(), vmax=z.max())
 
        kwargs_3d = dict(cmap=cmap, alpha=0.9, norm=norm)
        kwargs_2d = dict(cmap=cmap, origin='upper', norm=norm)
        
        # Original function - 3D
        ax1 = fig.add_subplot(241, projection='3d')
        ax1.plot_surface(grid_x, grid_y, z, **kwargs_3d)
        ax1.set_title('Original (3D)')
        
        # Polynomial approximation - 3D
        ax2 = fig.add_subplot(242, projection='3d')
        ax2.plot_surface(grid_x, grid_y, poly_pred.reshape(n_samples, n_samples), **kwargs_3d)
        ax2.set_title(f'Polynomial p=4 (3D)\nMSE: {poly_error:.6f}')
        
        # Fourier approximation - 3D
        ax3 = fig.add_subplot(243, projection='3d')
        ax3.plot_surface(grid_x, grid_y, fourier_pred.reshape(n_samples, n_samples), **kwargs_3d)
        ax3.set_title(f'Fourier p=4 (3D)\nMSE: {fourier_error:.6f}')
        
        # Fourier Q approximation - 3D
        ax4 = fig.add_subplot(244, projection='3d')
        ax4.plot_surface(grid_x, grid_y, fourierq_pred.reshape(n_samples, n_samples), **kwargs_3d)
        ax4.set_title(f'Fourier Q q=4 (3D)\nMSE: {fourierq_error:.6f}')
        
        # Original function - 2D heatmap
        ax5 = fig.add_subplot(245)
        im5 = ax5.imshow(z.T, **kwargs_2d)
        ax5.set_title('Original (2D)')
        
        # Polynomial approximation - 2D heatmap
        ax6 = fig.add_subplot(246)
        im6 = ax6.imshow(poly_pred.reshape(n_samples, n_samples).T, **kwargs_2d)
        ax6.set_title('Polynomial p=4 (2D)')
        
        # Fourier approximation - 2D heatmap
        ax7 = fig.add_subplot(247)
        im7 = ax7.imshow(fourier_pred.reshape(n_samples, n_samples).T, **kwargs_2d)
        ax7.set_title('Fourier p=4 (2D)')
        
        # Fourier Q approximation - 2D heatmap
        ax8 = fig.add_subplot(248)
        im8 = ax8.imshow(fourierq_pred.reshape(n_samples, n_samples).T, **kwargs_2d)
        ax8.set_title('Fourier Q q=4 (2D)')
        
        # Add colorbar
        # plt.colorbar(im5, ax=[ax5, ax6, ax7, ax8], shrink=0.8)
        
        plt.tight_layout()
        plt.show()
    except ImportError:
        print("Matplotlib not available for visualization")
 
if __name__ == "__main__":
    _test()
    _test_fit_curve()