这是一个非常有启发的例子,可以扩展到生产环境做一些模型!

public class PredictGenderTrain
{public String filePath;

    public static void main(String args[]){PredictGenderTrain dg = new PredictGenderTrain();//生成类实例,这种写法忘了叫什么了,故弄玄虚的感觉,谁知道告我一下,我喜欢主类里只有main函数的写法
        dg.filePath =  System.getProperty("user.dir") + "\\src\\main\\resources\\PredictGender\\Data";//找到数据路径
        dg.train();//调用train函数
    }/**
     * This function uses GenderRecordReader and passes it to RecordReaderDataSetIterator for further training.
     */
    public void train(){int seed = 123456;
        double learningRate = 0.01;
        int batchSize = 100;
        int nEpochs = 100;
        int numInputs = 0;
        int numOutputs = 0;
        int numHiddenNodes = 0;

        try(GenderRecordReader rr = new GenderRecordReader(new ArrayList<String>() {{add("M");add("F");}}))//这个try里面有小括号我也是头一次注意,括号里一般都是输入输出流,训练数据读取器作为临时变量,过后就会被自动回收,这里调用性别读取器类,后面会有这个类的详细解释{long st = System.currentTimeMillis();//打印当前时间
            System.out.println("Preprocessing start time : " + st);

            rr.initialize(new FileSplit(new File(this.filePath)));//初始化读取器

            long et = System.currentTimeMillis();//打印当前时间,处理时间
            System.out.println("Preprocessing end time : " + et);
            System.out.println("time taken to process data : " + (et-st) + " ms");

            numInputs = rr.maxLengthName * 5;  // multiplied by 5 as for each letter we use five binary digits like 00000//每个字符用5个二进制表示,输入大小就是最长名字的5倍
            numOutputs = 2;//输出大小为2
            numHiddenNodes = 2 * numInputs + numOutputs;//隐含层大小

            GenderRecordReader rr1 = new GenderRecordReader(new ArrayList<String>() {{add("M");add("F");}});//又搞了一个读取器

            DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, numInputs, 2);//训练迭代器
            System.out.println(trainIter);
            //System.exit(0);
            DataSetIterator testIter = new RecordReaderDataSetIterator(rr1, batchSize, numInputs, 2);//测试迭代器

            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()//网络还是一样,假装自己是老司机.seed(seed).biasInit(1).regularization(true).l2(1e-4).iterations(1).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate).updater(Updater.NESTEROVS).momentum(0.9)//采用梯度修正的参数修正方法.list().layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes).weightInit(WeightInit.XAVIER).activation("relu").build()).layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes).weightInit(WeightInit.XAVIER).activation("relu").build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes).nOut(numOutputs).build()).pretrain(false).backprop(true).build();

            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();
            model.setListeners(new HistogramIterationListener(10));

            for ( int n = 0; n < nEpochs; n++){//按步走while(trainIter.hasNext()){//按批走model.fit(trainIter.next());//训练模型
                }trainIter.reset();//每步走完数据重新来
            }ModelSerializer.writeModel(model,this.filePath + "PredictGender.net",true);//通过模型序列化方法把模型写到指定路径

            System.out.println("Evaluate model....");
            Evaluation eval = new Evaluation(numOutputs);//评价模型,这个也是老套路了,最终打印评价矩阵
            while(testIter.hasNext()){DataSet t = testIter.next();
                INDArray features = t.getFeatureMatrix();
                INDArray lables = t.getLabels();
                INDArray predicted = model.output(features,false);

                eval.eval(lables, predicted);

            }//Print the evaluation statistics
            System.out.println(eval.stats());
        }catch(Exception e){System.out.println("Exception111 : " + e.getMessage());
        }}
}                                                           
public class GenderRecordReader extends LineRecordReader//性别读取器,把什么样的数据放入深度学习网络其实就是在建模,这里把名字中的字符都映射成二进制,这也决定了隐层的输入
{// list to hold labels passed in constructor
    private List<String> labels;//标签数组

    // Final list that contains actual binary data generated from person name, it also contains label (1 or 0) at the end
    private List<String> names = new ArrayList<String>();名字的二值数组,包括二值标签

    // String to hold all possible alphabets from all person names in raw data
    // This String is used to convert person name to binary string seperated by comma
    private String possibleCharacters = "";//来自于名字的字母表,用于把逗号分隔的名字转成的二进制数

    // holds length of largest name out of all person names
    public int maxLengthName = 0;//名字最长的长度

    // holds total number of names including both male and female names.
    // This variable is not used in PredictGenderTrain.java
    private int totalRecords = 0;//总共的名字数量

    // iterator for List "names" to be used in next() method
    private Iterator<String> iter;//名字迭代器

    /**
     * Constructor to allow client application to pass List of possible Labels//允许客户端程序传一串名字
     * @param labels - List of String that client application pass all possible labels, in our case "M" and "F"
     */
    public GenderRecordReader(List<String> labels)//传入标签的方法{this.labels = labels;
        //this.labels = this.labels.stream().map(element -> element + ".csv").collect(Collectors.toList());
        //System.out.println("labels : " + this.labels);
    }/**
     * returns total number of records in List "names"
     * @return - totalRecords
     */
    private int totalRecords()//返回名字数量{return totalRecords;
    }/**
     * This function does following steps//这函数做了一下几件事
1.找到具体路径的文件
2.文件以逗号分隔名字和性别
3.每个性别对应一个文件
4.把名字定位临时文件
5.把名字的字符转成二进制
6.合并每个名字所有字符的二进制
7.找出唯一字符表去产生二值字符串
8.从文件中取出等量的记录,保证数据平衡
9.这个函数使用java8的stream特征,只需不到1分钟,普通方式要处理5-7分钟,
10.找到转换后的二值文件
11.把名字列表设置成可迭代模式
     * - Looks for the files with the name (in specified folder) as specified in labels set in constructor
     * - File must have person name and gender of the person (M or F),
     *   e.g. Deepan,M
     *        Trupesh,M
     *        Vinay,M
     *        Ghanshyam,M
     *
     *        Meera,F
     *        Jignasha,F
     *        Chaku,F
     *
     * - File for male and female names must be different, like M.csv, F.csv etc.
     * - populates all names in temporary list
     * - generate binary string for each alphabet for all person names
     * - combine binary string for all alphabets for each name
     * - find all unique alphabets to generate binary string mentioned in above step
     * - take equal number of records from all files. To do that, finds minimum record from all files, and then takes
     *   that number of records from all files to keep balance between data of different labels.
     * - Note : this function uses stream() feature of Java 8, which makes processing faster. Standard method to process file takes more than 5-7 minutes.
     *          using stream() takes approximately 800-900 ms only.
     * - Final converted binary data is stored List<String> names variable
     * - sets iterator from "names" list to be used in next() function
     * @param split - user can pass directory containing .CSV file for that contains names of male or female//以性别命名文件的目录
     * @throws IOException
     * @throws InterruptedException
     */
函数把名字字符串转成二进制字符串,这是该算法的核心思路
1.从可能的字符集中寻找数字等价的字符
2.对每个字符生成二进制字符
3.用0补足5位
4.合并单个名字的二进制字符
5.右补0保证所有名字二进制长度一致
6.添加1,0标签/**
     * This function gives binary string for full name string
     * - It uses "PossibleCharacters" string to find the decimal equivalent to any alphabet from it
     * - generate binary string for each alphabet
     * - left pads binary string for each alphabet to make it of size 5
     * - combine binary string for all alphabets of a name
     * - Right pads complete binary string to make it of size that is the size of largest name to keep all name length of equal size
     * - appends label value (1 or 0 for male or female respectively)
     * @param name - person name to be converted to binary string
     * @param gender - variable to decide value of label to be added to name's binary string at the end of the string
     * @return
     */
    private String getBinaryString(String name, int gender){String binaryString = "";
        for (int j = 0; j < name.length(); j++)//对每个名字,遍历每个字符,从字符集中找到索引,把索引转成二进制,并补足5位0{String fs = org.apache.commons.lang3.StringUtils.leftPad(Integer.toBinaryString(this.possibleCharacters.indexOf(name.charAt(j))),5,"0");
            binaryString = binaryString + fs;
        }//binaryString = String.format("%-" + this.maxLengthName*5 + "s",binaryString).replace(' ','0'); // this takes more time than StringUtils, hence commented

        binaryString  = org.apache.commons.lang3.StringUtils.rightPad(binaryString,this.maxLengthName*5,"0");//这名字处理完了,要保证最大长度一致,右补0,比如某人名字是一个字符串,最长是两个字符串,缺的就补0
        binaryString = binaryString.replaceAll(".(?!$)", "$0,");//这里是一个鬼畜般的用法,老衲也是查了半天,$是结束符,
?!$代表不是结束符,.(?!$)代表不是结束符的一个字符,$0是找到这个字符, 整个的意思是只要没到结束,把每个字符替换成这个字符后面加逗号,这样就把输入分开了
        //System.out.println("binary String : " + binaryString);        return binaryString + "," + String.valueOf(gender);    }}
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException//由于继承线性读取器,需要重写各方法
{if(split instanceof FileSplit)//如果split是FileSplit的实例,注意FileSplit继承BaseInputSplit,BaseInputSplit继承
InputSplit,split是InputSplit类{URI[] locations = split.locations();//文件定位,感觉方法还是挺全的

        System.out.println(locations[0]);

        if(locations != null && locations.length >= 1)//至少有俩文件{String longestName = "";//最长名字
            String uniqueCharactersTempString = "";//唯一字符
            List<Pair<String, List<String>>> tempNames = new ArrayList<Pair<String, List<String>>>();//临时名字数组
            for(URI location : locations){//遍历每个路径File file = new File(location);//路径对应文件夹

                List<String> temp  = this.labels.stream().filter(line -> file.getName().equals(line + ".csv")).collect(Collectors.toList());//这明明就是我最喜欢的scala的写法啊,过滤文件夹下名为性别的csv文件,组成数组
                if(temp.size() > 0)//要有文件{java.nio.file.Path path = Paths.get(file.getAbsolutePath());//找到路径
                    List<String> tempList = java.nio.file.Files.readAllLines(path, Charset.defaultCharset()).stream().map(element -> element.split(",")[0]).collect(Collectors.toList());//又是scala写法,按行读取文件夹下所有数据,并按逗号切分,取出第一列也就是名字构成数组

                    Optional<String> optional = tempList.stream().reduce((name1, name2)->name1.length() >= name2.length() ? name1 : name2);//还是scala,求出最长名字

                    if (optional.isPresent() && optional.get().length() > longestName.length())//还是Scala方法,
.isPresent()相当于scala Option的some(),也就是不为空且比最长的还长longestName = optional.get();//赋值给最长字符串

                    uniqueCharactersTempString = uniqueCharactersTempString + tempList.toString();//把名字数组转成字符串
                    Pair<String,List<String>> tempPair = new Pair<String,List<String>>(temp.get(0),tempList);
                    tempNames.add(tempPair);//把文件名和名字数组构成pair装入tempNames数组
                }else
                    throw new InterruptedException("File missing for any of the specified labels");//没文件报错
            }this.maxLengthName = longestName.length();//赋值最大长度

            String unique = Stream.of(uniqueCharactersTempString).map(w -> w.split("")).flatMap(Arrays::stream).distinct().collect(Collectors.toList()).toString();//求名字字符串的唯一字符,详细的不说了,都是类似scala语法

            char[] chars = unique.toCharArray();//唯一字符转成字符数组
            Arrays.sort(chars);//升序排列字符数组

            unique = new String(chars);//再转成字符串
            unique = unique.replaceAll("\\[", "").replaceAll("\\]","").replaceAll(",","");//去掉方括号逗号
            if(unique.startsWith(" "))unique = " " + unique.trim();//如果是tab,whithspace开头,去掉

            this.possibleCharacters = unique;//赋值给唯一字符串

            Pair<String, List<String>> tempPair = tempNames.get(0);//拿出第一个文件
            int minSize = tempPair.getValue().size();//计算文件数据量
            for(int i=1;i<tempNames.size();i++)//循环找到最小的数据量{if (minSize > tempNames.get(i).getValue().size())minSize = tempNames.get(i).getValue().size();
            }List<Pair<String, List<String>>> oneMoreTempNames = new ArrayList<Pair<String, List<String>>>();
            for(int i=0;i<tempNames.size();i++)//循环文件{int diff = Math.abs(minSize - tempNames.get(i).getValue().size());//看每个文件数据量比最小的多多少
                List<String> tempList = new ArrayList<String>();

                if (tempNames.get(i).getValue().size() > minSize) {如果比最小的大,只取最小长度的数据tempList = tempNames.get(i).getValue();
                    tempList = tempList.subList(0, tempList.size() - diff);
                }else
                    tempList = tempNames.get(i).getValue();//如果一样长保持不变
                Pair<String, List<String>> tempNewPair = new Pair<String, List<String>>(tempNames.get(i).getKey(),tempList);
                oneMoreTempNames.add(tempNewPair);//这样所有文件数据量都一样了
            }tempNames.clear();

            List<Pair<String, List<String>>> secondMoreTempNames = new ArrayList<Pair<String, List<String>>>();

            for(int i=0;i<oneMoreTempNames.size();i++)//遍历刚才的数组{int gender = oneMoreTempNames.get(i).getKey().equals("M") ? 1 : 0;//给M编号1,F编号0
                List<String> secondList = oneMoreTempNames.get(i).getValue().stream().map(element -> getBinaryString(element.split(",")[0],gender)).collect(Collectors.toList());//把名字转成二进制,并加上新编的类别号
                Pair<String,List<String>> secondTempPair = new Pair<String, List<String>>(oneMoreTempNames.get(i).getKey(),secondList);
                secondMoreTempNames.add(secondTempPair);//放入数组
            }oneMoreTempNames.clear();//清空

            for(int i=0;i<secondMoreTempNames.size();i++){names.addAll(secondMoreTempNames.get(i).getValue());//把所有文件名加到二进制名字数组
            }secondMoreTempNames.clear();//清空
            this.totalRecords = names.size();//二进制名字总数
            Collections.shuffle(names);//shuffle
            this.iter = names.iterator();//变成迭代器
        }else
            throw new InterruptedException("File missing for any of the specified labels");
    }else if (split instanceof InputStreamInputSplit){System.out.println("InputStream Split found...Currently not supported");
        throw new InterruptedException("File missing for any of the specified labels");
    }
}/**
 * - takes onme record at a time from names list using iter iterator
 * - stores it into Writable List and returns it
 *
 * @return
 */
@Override
public List<Writable> next()//复写next方法,逗号分隔把数值转成小数且是一个可写的列表
{if (iter.hasNext()){List<Writable> ret = new ArrayList<>();
        String currentRecord = iter.next();
        String[] temp = currentRecord.split(",");
        for(int i=0;i<temp.length;i++){ret.add(new DoubleWritable(Double.parseDouble(temp[i])));
        }return ret;
    }else
        throw new IllegalStateException("no more elements");
}@Override
public boolean hasNext()//复写hasNext
{if(iter != null) {return iter.hasNext();
    }throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
}@Override
public void close() throws IOException {}@Override
public void setConf(Configuration conf) {this.conf = conf;
}@Override
public Configuration getConf() {return conf;
}@Override
public void reset()//复写reset,把保存的名字赋给迭代器
{this.iter = names.iterator();
}

深度学习-根据名字识别男女相关推荐

  1. 基于深度学习的人脸识别系统系列(Caffe+OpenCV+Dlib)——【六】设计人脸识别的识别类...

    前言 基于深度学习的人脸识别系统,一共用到了5个开源库:OpenCV(计算机视觉库).Caffe(深度学习库).Dlib(机器学习库).libfacedetection(人脸检测库).cudnn(gp ...

  2. OCR技术系列之四】基于深度学习的文字识别(3755个汉字)(转)

    上一篇提到文字数据集的合成,现在我们手头上已经得到了3755个汉字(一级字库)的印刷体图像数据集,我们可以利用它们进行接下来的3755个汉字的识别系统的搭建.用深度学习做文字识别,用的网络当然是CNN ...

  3. 基于深度学习的人脸识别与管理系统(UI界面增强版,Python代码)

    摘要:人脸检测与识别是机器视觉领域最热门的研究方向之一,本文详细介绍博主自主设计的一款基于深度学习的人脸识别与管理系统.博文给出人脸识别实现原理的同时,给出Python的人脸识别实现代码以及PyQt设 ...

  4. dlib 使用OpenCV,Python和深度学习进行人脸识别 源代码

    请直接访问原文章 dlib 使用OpenCV,Python和深度学习进行人脸识别 源代码 https://hotdog29.com/?p=595 在 2019年7月7日 上张贴 由 hotdog发表回 ...

  5. 【OCR技术系列之四】基于深度学习的文字识别(3755个汉字)

    上一篇提到文字数据集的合成,现在我们手头上已经得到了3755个汉字(一级字库)的印刷体图像数据集,我们可以利用它们进行接下来的3755个汉字的识别系统的搭建.用深度学习做文字识别,用的网络当然是CNN ...

  6. python模块cv2人脸识别_手把手教你使用OpenCV,Python和深度学习进行人脸识别

    使用OpenCV,Python和深度学习进行人脸识别 在本教程中,你将学习如何使用OpenCV,Python和深度学习进行面部识别.首先,我们将简要讨论基于深度学习的面部识别,包括"深度度量 ...

  7. 基于深度学习的口罩识别与检测PyTorch实现

    基于深度学习的口罩识别与检测PyTorch实现 1. 设计思路 1.1 两阶段检测器:先检测人脸,然后将人脸进行分类,戴口罩与不戴口罩. 1.2 一阶段检测器:直接训练口罩检测器,训练样本为人脸的标注 ...

  8. 用OpenCV和深度学习进行年龄识别

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达本文转自|机器学习算法那些事 在本教程中,您将学习如何使用OpenC ...

  9. 基于深度学习的脑电图识别 综述篇(三)模型分析

    作者|Memory逆光 本文由作者授权分享 导读 脑电图(EEG)是一个复杂的信号,一个医生可能需要几年的训练并利用先进的信号处理和特征提取方法,才能正确解释其含义.而如今机器学习和深度学习的发展,大 ...

最新文章

  1. mysql insert union_在MySQL中使用INSERT INTO SELECT和UNION执行多次插入
  2. DevOps - Spring Boot自动部署到WebLogic
  3. RESTful Web 服务 - 缓存
  4. Python入门实战题目
  5. 他是世界首位惯性导航博士!如今101岁,依然对航天事业激情澎湃
  6. 实现APP-V服务全程跟踪(二)
  7. (原创)如何进行有符号小数乘法运算?(Verilog)
  8. ASP.NET网站SESSION丢失的问题
  9. 利用java反射原理写了一个简单赋值和取值通用类【改】
  10. dockerfile安装jenkins 并配置构建工具(node、npm、maven、git)
  11. SQLServer 2008 r2 下载地址(百度云)及安装图解
  12. QQ的DLL文件修改大全!
  13. PHP数字金额转换成中文大写金额
  14. 当企业网站跳出率超过70%,就要查找原因改进了
  15. windows各版本序列号集合
  16. 基于ThreeJS的3D地球
  17. python类的魔法方法和装饰器
  18. 打开word时提示需要安装包gaozhi.msi
  19. 计算机网络中删除自己的共享,如何删除我的电脑中共享文档
  20. [PCB]这里带你了解何为PCB?

热门文章

  1. MySQL 数据表优化设计(六):id 该如何选择数据类型?
  2. python发送邮件时报: Error: need RCPT command
  3. matlab二维图像重采样,使用网格插值对图像重采样
  4. Heap size 80869K exceeds notification threshold (51200K)
  5. html旅游门票源代码,票务网站整套静态模板 HTML模板
  6. MVP架构开发的鼠绘漫画客户端
  7. ESP8266 复位 ets Jan 8 2013,rst cause:4, boot mode:(3,7)
  8. 爬虫练习-爬取《斗破苍穹》全文小说
  9. 利用派生类实现统一接口解决三种基础排序问题
  10. 移动医疗应用遍地开花,却抓不住用户的核心需求