softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之获取本机ip数据包示例
Feb 10 Python
用Python代码来解图片迷宫的方法整理
Apr 02 Python
使用Python脚本将文字转换为图片的实例分享
Aug 29 Python
python数据结构之链表的实例讲解
Jul 25 Python
解决Python 中英文混输格式对齐的问题
Jul 16 Python
python pytest进阶之fixture详解
Jun 27 Python
python实现两张图片拼接为一张图片并保存
Jul 16 Python
python threading和multiprocessing模块基本用法实例分析
Jul 25 Python
浅析python 中大括号中括号小括号的区分
Jul 29 Python
Django框架HttpRequest对象用法实例分析
Nov 01 Python
python使用for...else跳出双层嵌套循环的方法实例
May 17 Python
Python 处理日期时间的Arrow库使用
Aug 18 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
ThinkPHP应用模式扩展详解
2014/07/16 PHP
ThinkPHP中的常用查询语言汇总
2014/08/22 PHP
php实现递归抓取网页类实例
2015/04/03 PHP
PHP实现事件机制实例分析
2015/06/26 PHP
PHP6新特性分析
2016/03/03 PHP
JS小功能(offsetLeft实现图片滚动效果)实例代码
2013/11/28 Javascript
jQuery中ajax的post()方法用法实例
2014/12/26 Javascript
JQuery中clone方法复制节点
2015/05/18 Javascript
详谈javascript异步编程
2016/02/21 Javascript
使用JavaScript实现ajax的实例代码
2016/05/11 Javascript
深入浅析knockout源码分析之订阅
2016/07/12 Javascript
vue组件父子间通信之综合练习(聊天室)
2017/11/07 Javascript
Vue的路由动态重定向和导航守卫实例
2018/03/17 Javascript
简述JS控制台的使用
2018/07/15 Javascript
element-ui表格数据转换的示例代码
2018/08/24 Javascript
详解vue后台系统登录态管理
2019/04/02 Javascript
微信小程序-form表单提交代码实例
2019/04/29 Javascript
基于小程序请求接口wx.request封装的类axios请求
2020/07/02 Javascript
Javascript数组及类数组相关原理详解
2020/10/29 Javascript
如何在Vue项目中添加接口监听遮罩
2021/01/25 Vue.js
[02:37]2018DOTA2亚洲邀请赛赛前采访 VP.no[o]ne心中最强SOLO是谁
2018/04/04 DOTA
python多线程抓取天涯帖子内容示例
2014/04/03 Python
使用python调用浏览器并打开一个网址的例子
2014/06/05 Python
在Python程序和Flask框架中使用SQLAlchemy的教程
2016/06/06 Python
Python实现计算圆周率π的值到任意位的方法示例
2018/05/08 Python
让你的Python代码实现类型提示功能
2019/11/19 Python
使用NumPy读取MNIST数据的实现代码示例
2019/11/20 Python
Python requests设置代理的方法步骤
2020/02/23 Python
Python基于Serializer实现字段验证及序列化
2020/11/04 Python
selenium如何定位span元素的实现
2021/01/13 Python
Expedia挪威官网:酒店、机票和租车
2018/03/03 全球购物
塑料制成的可水洗的编织平底鞋和鞋子:Rothy’s
2018/09/16 全球购物
哥德堡通行证:Gothenburg Pass
2019/12/09 全球购物
利用异或运算实现两个无符号数的加法运算
2013/12/20 面试题
2014年导购员工作总结
2014/11/18 职场文书
MySQL数据库Innodb 引擎实现mvcc锁
2022/05/06 MySQL