softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的turtle库使用详解
May 10 Python
Python read函数按字节(字符)读取文件的实现
Jul 03 Python
浅析PEP570新语法: 只接受位置参数
Oct 15 Python
python实现超市商品销售管理系统
Nov 22 Python
PyTorch和Keras计算模型参数的例子
Jan 02 Python
Python对象的属性访问过程详解
Mar 05 Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 Python
Python selenium模块实现定位过程解析
Jul 09 Python
详解Python中的文件操作
Jan 14 Python
python学习之使用Matplotlib画实时的动态折线图的示例代码
Feb 25 Python
pytorch 计算Parameter和FLOP的操作
Mar 04 Python
Pytorch中expand()的使用(扩展某个维度)
Jul 15 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
PHP4在WinXP下IIS和Apache2服务器上的安装实例
2006/10/09 PHP
解析PHP SPL标准库的用法(遍历目录,查找固定条件的文件)
2013/06/18 PHP
laravel安装和配置教程
2014/10/29 PHP
PHP中返回引用类型的方法
2015/04/03 PHP
php使用Jpgraph绘制饼状图的方法
2015/06/10 PHP
php实现构建排除当前元素的乘积数组方法
2018/10/06 PHP
确保Laravel网站不会被嵌入到其他站点中的方法
2019/10/18 PHP
PHP数据源架构模式之表入口模式实例分析
2020/01/23 PHP
TP5框架页面跳转样式操作示例
2020/04/05 PHP
jquery 框架使用教程 AJAX篇
2009/10/11 Javascript
再谈javascript图片预加载技术(详细演示)
2011/03/12 Javascript
用Javascript获取页面元素的具体位置
2013/12/09 Javascript
JavaScript Promise启示录
2014/08/12 Javascript
10分钟学会写Jquery插件实例教程
2014/09/06 Javascript
JavaScript计时器示例分析
2015/02/05 Javascript
jQuery控制文本框只能输入数字和字母及使用方法
2016/05/26 Javascript
基于Bootstrap实现下拉菜单项和表单导航条(两个菜单项,一个下拉菜单和登录表单导航条)
2016/07/22 Javascript
jquery封装插件时匿名函数形参和实参的写法解释
2017/02/14 Javascript
Vue.directive自定义指令的使用详解
2017/03/10 Javascript
详解node.js平台下Express的session与cookie模块包的配置
2017/04/26 Javascript
深入理解Node.js中通用基础设计模式
2017/09/19 Javascript
bootstrap 通过加减按钮实现输入框组功能
2017/11/15 Javascript
nuxt踩坑之Vuex状态树的模块方式使用详解
2019/09/06 Javascript
vue实现评价星星功能
2020/06/30 Javascript
Python 编码处理-str与Unicode的区别
2016/09/06 Python
TensorFlow实现创建分类器
2018/02/06 Python
python库skimage给灰度图像染色的方法示例
2020/04/27 Python
pandas 像SQL一样使用WHERE IN查询条件说明
2020/06/05 Python
解决redis与Python交互取出来的是bytes类型的问题
2020/07/16 Python
Python 远程开关机的方法
2020/11/18 Python
Superdry瑞典官网:英国日本街头风品牌
2017/05/17 全球购物
介绍一下你对SOA的认识
2016/04/24 面试题
行政助理求职自荐信
2013/10/26 职场文书
信息服务专业毕业生求职信
2014/03/02 职场文书
小学生国庆65周年演讲稿范文(2篇)
2014/09/21 职场文书
个人工作能力自我评价
2015/03/05 职场文书