"""
Name: func.py
Author: Yuxiang LI (li.yuxiang.nj@gmail.com)
Date: 25/03/2016

Description: Basic (and slow) CPU implementation of the neural-network denoiser
"""

import numpy as np
import scipy.io as sio
from func import int_sqrt


def load_model(weights=None):
    """
    Load a pre-trained fully-connected denoising model from a mat file.

    The file must contain a variable ``w`` holding one matrix per layer
    (read as ``loadmat(...)['w'][0]``). Each matrix has shape
    ``(n_out, n_in + 1)``: the first ``n_in`` columns are the layer weights
    and the last column is the bias. Every layer except the last applies
    ``tanh``; the last layer is linear.

    :param weights: path of the mat file holding the model weights;
        when None, ``'weights_cvpr.mat'`` is used
    :return: denoising model as a dict with keys
        ``'patch_in'``  -- ``int_sqrt`` of the first layer's input size
                           (presumably the side of a square input patch),
        ``'patch_out'`` -- ``int_sqrt`` of the last layer's output size,
        ``'predict'``   -- function taking one input with ``n_in`` elements
                           (reshaped in C order to a column) and returning
                           the network output as an ``(n_out, 1)`` ndarray
    """
    # Previously the argument was ignored and this file was always loaded;
    # keep it as the fallback so existing behaviour is available.
    weight_file = 'weights_cvpr.mat' if weights is None else weights

    w = sio.loadmat(weight_file)['w'][0]

    # Split every layer once into (weight matrix, bias column) so predict()
    # does not re-slice on each call. The bias is kept 2-D, (n_out, 1), so it
    # adds directly to the (n_out, 1) product. Plain ndarrays are used
    # because np.mat / np.matrix was removed from NumPy 2.0.
    layers = [(np.asarray(m)[:, :-1], np.asarray(m)[:, -1:]) for m in w]
    hidden_layers = layers[:-1]
    out_weight, out_bias = layers[-1]
    n_in = layers[0][0].shape[1]

    def predict(box):
        # Flatten the input into a single column vector.
        x = np.asarray(box).reshape((n_in, 1))
        for weight, bias in hidden_layers:
            x = np.tanh(weight.dot(x) + bias)
        # Output layer is linear (no activation).
        return out_weight.dot(x) + out_bias

    return {'patch_in': int_sqrt(w[0].shape[1] - 1),
            'patch_out': int_sqrt(w[-1].shape[0]),
            'predict': predict}
