代码智能:问题与解法
代码智能:问题与解法
在基于预训练大模型引发自然语言处理革命的今天,代码智能技术也在迅速跟进发展。
那么,代码智能主要在做一些什么样的事情呢?可能很多同学会有比较科幻的想法,比如程序员要失业了之类的。
但是,其实很多工作并没有那么神秘,非常基础。那么我们用代码智能要解决什么问题呢?
- 判断两段代码是不是实现相似的功能
- 搜索跟当前代码段最相似的代码
- 检测代码是否有bug
- 自动修复代码中的bug
- 给一段代码自动写注释
- 根据文本推荐最相似的代码段
- 根据文本生成代码
看了之后是不是觉得更玄幻了?这么困难的问题怎么搞得定?
诚实地讲,这其中的每个子问题都很困难,就算是人类学习起来也很困难。
不过,正像是人类也是一步一步学会的一样,机器也在不断地进步。我们需要的不一定是万能的机器神,也是和我们一样普通的机器人,它们有很大的局限,但是它们可以帮助我们减轻不少工作量。
而且,最后一节我们将揭晓,处理这么多如此复杂问题的方法,却非常简单,一把梭哈,我们只用一个模型就能搞定。
下面我们就详细看一看这些问题的细节。
问题:克隆检测 Clone Detection
万地高楼平地起,代码智能任务首先从克隆检测做起。
所谓克隆检测,就是寻找写法和功能上相似的代码。
不要小看代码重复,它会显著地降低代码智能训练的有效性。
我们看下图,训练集中有重复,测试集中有重复,它们的交集中仍然有重复,在论文《The Adverse Effects of Code Duplication in Machine Learning Models of Code》中有详细的分析。
预测两段代码是否相似
以下的例子来自BigCloneBench数据集. 论文地址在:https://arxiv.org/pdf/2002.08653.pdf
下面我们举几个例子来看什么算相似:
代码1:
private StringBuffer encoder(String arg) {if (arg == null) {arg = "";}MessageDigest md5 = null;try {md5 = MessageDigest.getInstance("MD5");md5.update(arg.getBytes(SysConstant.charset));} catch (Exception e) {e.printStackTrace();}return toHex(md5.digest());}
代码2:
public String kodetu(String testusoila) {MessageDigest md = null;try {md = MessageDigest.getInstance("SHA");md.update(testusoila.getBytes("UTF-8"));} catch (NoSuchAlgorithmException e) {new MezuLeiho("Ez da zifraketa algoritmoa aurkitu", "Ados", "Zifraketa Arazoa", JOptionPane.ERROR_MESSAGE);e.printStackTrace();} catch (UnsupportedEncodingException e) {new MezuLeiho("Errorea kodetzerakoan", "Ados", "Kodeketa Errorea", JOptionPane.ERROR_MESSAGE);e.printStackTrace();}byte raw[] = md.digest();String hash = (new BASE64Encoder()).encode(raw);return hash;}
代码2的字符串是用巴斯克语写的。它们用的算法也有区别,判空和异常处理也有不同,但是我们认为它们是很类似的,属于克隆识别认为相同或高度相似的。
我们再看一对例子:
代码1:
public static void test(String args[]) {int trace;int bytes_read = 0;int last_contentLenght = 0;try {BufferedReader reader;URL url;url = new URL(args[0]);URLConnection istream = url.openConnection();last_contentLenght = istream.getContentLength();reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));System.out.println(url.toString());String line;trace = t2pNewTrace();while ((line = reader.readLine()) != null) {bytes_read = bytes_read + line.length() + 1;t2pProcessLine(trace, line);}t2pHandleEventPairs(trace);t2pSort(trace, 0);t2pExportTrace(trace, new String("pngtest2.png"), 1000, 700, (float) 0, (float) 33);t2pExportTrace(trace, new String("pngtest3.png"), 1000, 700, (float) 2.3, (float) 2.44);System.out.println("Press any key to contiune read from stream !!!");System.out.println(t2pGetProcessName(trace, 0));System.in.read();istream = url.openConnection();if (last_contentLenght != istream.getContentLength()) {istream = url.openConnection();istream.setRequestProperty("Range", "bytes=" + Integer.toString(bytes_read) + "-");System.out.println(Integer.toString(istream.getContentLength()));reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));while ((line = reader.readLine()) != null) {System.out.println(line);t2pProcessLine(trace, line);}} else System.out.println("File not changed !");t2pDeleteTrace(trace);} catch (MalformedURLException e) {System.out.println("MalformedURLException !!!");} catch (IOException e) {System.out.println("File not found " + args[0]);};}
代码2:
private static String loadUrlToString(String a_url) throws IOException {URL l_url1 = new URL(a_url);BufferedReader br = new BufferedReader(new InputStreamReader(l_url1.openStream()));String l_content = "";String l_ligne = null;l_content = br.readLine();while ((l_ligne = br.readLine()) != null) {l_content += AA.SL + l_ligne;}return l_content;}
这个虽然没有涉及小语种,但是明显代码长度差异巨大。不过,我们仍然认为它们是相似的。
我们看一对不相似的吧:
代码1:
private void setNodekeyInJsonResponse(String service) throws Exception {String filename = this.baseDirectory + service + ".json";Scanner s = new Scanner(new File(filename));PrintWriter fw = new PrintWriter(new File(filename + ".new"));while (s.hasNextLine()) {fw.println(s.nextLine().replaceAll("NODEKEY", this.key));}s.close();fw.close();(new File(filename + ".new")).renameTo(new File(filename));}
代码2:
public void transform(String style, String spec, OutputStream out) throws IOException {URL url = new URL(rootURL, spec);InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));transform(style, in, out);in.close();}
不相似的就不解释了。
BigCloneBench数据集,就是提供了两段代码,以及它们是否相似的人工打标的结果。
数据分为train.txt, valid.txt, test.txt三个集合,它们的格式都是同样的:
idx1 idx2 0/1
其中idx1和idx2是两段代码在data.jsonl中的索引值,最后一个是它们是否相似的人工打标的值。
代码都保存在data.jsonl中,格式为:
{"func":"代码","idx":"idx值"}
我们以训练集train.txt为例,其前两行是这样的:
13988825 8660836 0
80378 18548122 1
13988825在data.jsonl中对应的结构是这样的:
{"func": " private void setNodekeyInJsonResponse(String service) throws Exception {\n String filename = this.baseDirectory + service + \".json\";\n Scanner s = new Scanner(new File(filename));\n PrintWriter fw = new PrintWriter(new File(filename + \".new\"));\n while (s.hasNextLine()) {\n fw.println(s.nextLine().replaceAll(\"NODEKEY\", this.key));\n }\n s.close();\n fw.close();\n (new File(filename + \".new\")).renameTo(new File(filename));\n }\n", "idx": "13988825"}
8660836对应的是:
{"func": " public void transform(String style, String spec, OutputStream out) throws IOException {\n URL url = new URL(rootURL, spec);\n InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n transform(style, in, out);\n in.close();\n }\n", "idx": "8660836"}
而它们的结果是不相似。大家看到,这个例子就是刚才上面我们写的第三个例子。
搜索跟当前代码段语义最相似的代码段
这个我们使用北大李戈李师团队的POJ-104数据集。
这个数据集需要到https://drive.google.com/uc?id=0B2i-vWnOu7MxVlJwQXN6eVNONUU去下载。
每个代码段用一个index来描述,然后code字段是完整的代码。我们来看个例子:
{"label":"1","index":"0","code":"
int f(int a,int x)
{int count=1,i;for(i=x;i<a;i++)if(a%i==0)count+=f(a/i,i);if(i==a)return count;elsereturn 0;
}void main()
{int n,a;scanf(\"%d\",&n);for(;n>0;n--){scanf(\"%d\",&a);if(a==1||a==2)printf(\"1\
\");elseprintf(\"%d\
\",f(a,2));}
}
"}
然后,这个任务的目的就是求出针对某一段代码最相似的代码段。以取top 2为例:输出的样例如下:
{"index": "0", "answers": ["3", "2"]}
{"index": "1", "answers": ["0", "4"]}
{"index": "2", "answers": ["0", "1"]}
{"index": "4", "answers": ["1", "5"]}
{"index": "3", "answers": ["4", "2"]}
{"index": "5", "answers": ["4", "3"]}
也就是说,针对于代码index 0, 最相似的代码段是 index 3和2.
index 3是这样的:
void qut(int a,int b); //????
int num=0; //?????????
int main()
{int i,n,g[1000]; //?????????cin>>n;for(i=0;i<n;i++) //??????cin>>g[i];for(i=0;i<n;i++){qut(g[i],1); //????cout<<num<<endl;num=0;}return 0;
}void qut(int a,int b)
{int i;if (a>=b) {num++; if (b==1) b++;for (i=b;i<=a;i++) {if (a%i==0) {qut(a/i,i); //??a%i==0,??}}}
}
问题:缺陷检测
缺陷检测的数据集非常简单粗暴,就是一段打标的代码,标识是不是有漏洞。
我们看个有漏洞的例子:
{"project":"FFmpeg","commit_id":"aba232cfa9b193604ed98f3fa505378d006b1b3b","target":1,"func":"static int r3d_read_rdvo(AVFormatContext *s, Atom *atom){R3DContext *r3d = s->priv_data;AVStream *st = s->streams[0];int i;r3d->video_offsets_count = (atom->size - 8) / 4;r3d->video_offsets = av_malloc(atom->size);if (!r3d->video_offsets)return AVERROR(ENOMEM);for (i = 0; i < r3d->video_offsets_count; i++) {r3d->video_offsets[i] = avio_rb32(s->pb);if (!r3d->video_offsets[i]) {r3d->video_offsets_count = i;break;}av_dlog(s, \"video offset %d: %#x\
\", i, r3d->video_offsets[i]);}if (st->r_frame_rate.num)st->duration = av_rescale_q(r3d->video_offsets_count,(AVRational){st->r_frame_rate.den,st->r_frame_rate.num},st->time_base);av_dlog(s, \"duration %\"PRId64\"\
\", st->duration);return 0;}
","idx":5}
信息就这么多,至于哪行是什么问题,训练集中没有。
当然,数据集里大部分还是没有漏洞的,比如第一条:
{"project": "FFmpeg", "commit_id": "973b1a6b9070e2bf17d17568cbaf4043ce931f51", "target": 0, "func": "static av_cold int vdadec_init(AVCodecContext *avctx)\n\n{\n\n VDADecoderContext *ctx = avctx->priv_data;\n\n struct vda_context *vda_ctx = &ctx->vda_ctx;\n\n OSStatus status;\n\n int ret;\n\n\n\n ctx->h264_initialized = 0;\n\n\n\n /* init pix_fmts of codec */\n\n if (!ff_h264_vda_decoder.pix_fmts) {\n\n if (kCFCoreFoundationVersionNumber < kCFCoreFoundationVersionNumber10_7)\n\n ff_h264_vda_decoder.pix_fmts = vda_pixfmts_prior_10_7;\n\n else\n\n ff_h264_vda_decoder.pix_fmts = vda_pixfmts;\n\n }\n\n\n\n /* init vda */\n\n memset(vda_ctx, 0, sizeof(struct vda_context));\n\n vda_ctx->width = avctx->width;\n\n vda_ctx->height = avctx->height;\n\n vda_ctx->format = 'avc1';\n\n vda_ctx->use_sync_decoding = 1;\n\n vda_ctx->use_ref_buffer = 1;\n\n ctx->pix_fmt = avctx->get_format(avctx, avctx->codec->pix_fmts);\n\n switch (ctx->pix_fmt) {\n\n case AV_PIX_FMT_UYVY422:\n\n vda_ctx->cv_pix_fmt_type = '2vuy';\n\n break;\n\n case AV_PIX_FMT_YUYV422:\n\n vda_ctx->cv_pix_fmt_type = 'yuvs';\n\n break;\n\n case AV_PIX_FMT_NV12:\n\n vda_ctx->cv_pix_fmt_type = '420v';\n\n break;\n\n case AV_PIX_FMT_YUV420P:\n\n vda_ctx->cv_pix_fmt_type = 'y420';\n\n break;\n\n default:\n\n av_log(avctx, AV_LOG_ERROR, \"Unsupported pixel format: %d\\n\", avctx->pix_fmt);\n\n goto failed;\n\n }\n\n status = ff_vda_create_decoder(vda_ctx,\n\n avctx->extradata, avctx->extradata_size);\n\n if (status != kVDADecoderNoErr) {\n\n av_log(avctx, AV_LOG_ERROR,\n\n \"Failed to init VDA decoder: %d.\\n\", status);\n\n goto failed;\n\n }\n\n avctx->hwaccel_context = vda_ctx;\n\n\n\n /* changes callback functions */\n\n avctx->get_format = get_format;\n\n avctx->get_buffer2 = get_buffer2;\n\n#if FF_API_GET_BUFFER\n\n // force the old get_buffer to be empty\n\n avctx->get_buffer = NULL;\n\n#endif\n\n\n\n /* init H.264 decoder */\n\n ret = ff_h264_decoder.init(avctx);\n\n if (ret < 0) {\n\n av_log(avctx, AV_LOG_ERROR, \"Failed to open H.264 decoder.\\n\");\n\n goto failed;\n\n }\n\n ctx->h264_initialized = 1;\n\n\n\n return 0;\n\n\n\nfailed:\n\n vdadec_close(avctx);\n\n return -1;\n\n}\n", "idx": 0}
推理搞起来也是十分省事了,就是对应每个index给个0或1的结果:
0 0
1 1
2 1
3 0
4 0
问题:代码自动修复
有了识别代码漏洞的,更进一步就是学习自动修复代码的了。
代码自动修复的题目也很简单,一段是有bug的代码,另一段是修复之后的代码。
我们来看一个例子:
有bug的代码是这样的:
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; }
修复之后是这样子的:
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; }
也真难为算法了,人看起来都有点费事。
问题:代码互译
比如实现C#语言和Java语言的互译。我们只要有一系列代码的C#写法和Java写法,就可以进行学习进行互译。
我们来看一对例子。
先看C#代码:
public virtual ListSpeechSynthesisTasksResponse ListSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request){var options = new InvokeOptions();options.RequestMarshaller = ListSpeechSynthesisTasksRequestMarshaller.Instance;options.ResponseUnmarshaller = ListSpeechSynthesisTasksResponseUnmarshaller.Instance;return Invoke<ListSpeechSynthesisTasksResponse>(request, options);
}
对应的Java
public ListSpeechSynthesisTasksResult listSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request) {request = beforeClientExecution(request);return executeListSpeechSynthesisTasks(request);
}
问题:给代码写注释
在训练素材中,有代码和注释,这个任务的目的为新代码写注释。评价指标是对于生成的注释的语言准确度。
这个我们使用CodeSearchNet数据集。
这个数据集中的每条记录的格式如下:
- repo: 仓库名
- path: 文件名
- func_name: 函数或方法名
- original_string: 未经处理的源字符串
- language: 编程语言
- code/function: 代码信息
- code_tokens/function_tokens: 分词之后的代码结果
- docstring: 注释字符串信息
- docstring_tokens: docstring分词之后的结果
- url: 自然语言的唯一标识号
- idx: 代码段的唯一标识号
我们来看个例子:
{"repo": "ciena-blueplanet/bunsen-core", "path": "src/reducer.js", "func_name": "", "original_string": "function
(state, action) {\n return _.defaults({\n isValidating: action.isValidating,\n lastAction: IS_VALIDA
TING\n }, state)\n }", "language": "javascript", "code": "function (state, action) {\n return _.defaults({\n isValidating: action.isValidating,\n lastAction: IS_VALIDATING\n }, state)\n }", "code_tokens":
["function", "(", "state", ",", "action", ")", "{", "return", "_", ".", "defaults", "(", "{", "isValidating", ":"
, "action", ".", "isValidating", ",", "lastAction", ":", "IS_VALIDATING", "}", ",", "state", ")", "}"], "docstrin
g": "Update is validating result\n@param {State} state - state to update\n@param {Action} action - action\n@retur
ns {State} - updated state", "docstring_tokens": ["Update", "is", "validating", "result"], "sha": "993c67e314e2b7
5003a1ff4c2f0cb667715562b2", "url": "https://github.com/ciena-blueplanet/bunsen-core/blob/993c67e314e2b75003a1ff4
c2f0cb667715562b2/src/reducer.js#L394-L399", "partition": "train"}
对于生成的自然语言,我们采用《ORANGE: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation 》论文的方法进行评分。
问题:为自然语言文本匹配最合适的代码段
我们仍然使用上一节的CodeSearchNet数据集。
这个搜索的结果类似于下面这样:
{"url": "url0", "answers": [10,11,12,13,14]}
{"url": "url1", "answers": [10,12,11,13,14]}
{"url": "url2", "answers": [13,11,12,10,14]}
{"url": "url3", "answers": [10,14,12,13,11]}
{"url": "url4", "answers": [10,11,12,13,14]}
配上UI,大致实现的效果是这样的:
或者是这样:
问题:根据自然语言生成代码
这是终极任务,就是根据一段文本描述硬生生地生成一段代码出来。
格式非常简单,就一段代码和一段文本。
我们来看个训练样本的例子:
{"code": "void function ( Binder arg0 ) { EventBus loc0 = new EventBus ( ) ; AmbariEventPublisher loc1 = new AmbariEventPublisher ( ) ; repla
ceEventBus ( AmbariEventPublisher . class , loc1 , loc0 ) ; arg0 . bind ( AmbariEventPublisher . class ) . toInstance ( loc1 ) ; }", "nl": "force the eventb us from ambarievent publisher to be serialand synchronous . concode_field_sep PlaceHolder placeHolder concode_field_sep void registerAlertListeners concode_elem_sep EventBus synchronizeAlertEventPublisher concode_elem_sep void replaceEventBus concode_elem_sep void registerAmbariListeners"}
这NL部分有点乱啊,没办法,为了增加数据量,没有那么多人手打精确的标。
我们再看一个:
{"code": "byte [ ] function ( Class < ? > arg0 , Configuration arg1 ) { return AuthenticationTokenSerializer . serialize ( org . apache . acc
umulo . core . client . mapreduce . lib . impl . ConfiguratorBase . getAuthenticationToken ( arg0 , arg1 ) ) ; }", "nl": "do n't use this . n
o , really , do n't use this . you already have an authenticationtoken with org.apache.accumulo.core.client.mapreduce.lib.impl.configuratorba
se #getauthenticationtoken class , configuration . you do n't need to construct it yourself . gets the password from the configuration . warn
ing : the password is stored in the configuration and shared with all mapreduce tasks ; it is base64 encoded to provide a charset safe conver
sion to a string , and is not intended to be secure . concode_field_sep PlaceHolder placeHolder concode_field_sep String getPrincipal concode
_elem_sep void setLogLevel concode_elem_sep Level getLogLevel concode_elem_sep Boolean isConnectorInfoSet concode_elem_sep String getTokenCla
ss concode_elem_sep void setZooKeeperInstance concode_elem_sep void setMockInstance concode_elem_sep Instance getInstance concode_elem_sep St
ring enumToConfKey concode_elem_sep void setConnectorInfo"}
是不是质量也没好到哪儿去?这就是CONCODE数据集的样子。
解法:基于大规模预训练模型的多任务学习
402年前,当努尔哈赤面临明朝多路大军的围困的时候,采取了“凭你几路来,我只一路去”的战术赢得了萨尔浒之战的立国之战。
我们同样学习古人的智慧,任你数据集千变万化,我们的工具就只用一个 - 大规模预训练模型。
下面是预训练模型的简要发展史:
以开头我们展示的微软的codebert模型为例,我们要处理上面最复杂的代码生成任务,只要一条命令就可以搞定:
python -m torch.distributed.launch --nproc_per_node=$PER_NODE_GPU run.py \--data_dir=$DATADIR \--langs=$LANG \--output_dir=$OUTPUTDIR \--pretrain_dir=$PRETRAINDIR \--log_file=$LOGFILE \--model_type=gpt2 \--block_size=512 \--do_train \--node_index 0 \--gpu_per_node $PER_NODE_GPU \--learning_rate=5e-5 \--weight_decay=0.01 \--evaluate_during_training \--per_gpu_train_batch_size=6 \--per_gpu_eval_batch_size=12 \--gradient_accumulation_steps=2 \--num_train_epochs=30 \--logging_steps=100 \--save_steps=5000 \--overwrite_output_dir \--seed=42
如果使用两张2 NVIDIA P100 GPU卡的话,22小时左右就可以训练完。
推理呢,也是一条语句就搞定:
python -u run.py \--data_dir=$DATADIR \--langs=$LANG \--output_dir=$OUTPUTDIR \--pretrain_dir=$PRETRAINDIR \--log_file=$LOGFILE \--model_type=gpt2 \--block_size=512 \--do_infer \--logging_steps=100 \--seed=42
只用一张P100卡,大约40分钟就可以搞定。
有了上面的基础,我们就可以去打比赛啦。上面介绍的数据集,全都是比赛的赛题:
上面提到的数据集,可以在https://github.com/microsoft/CodeXGLUE下载到。
欢迎来到代码智能的世界!
附录:快速上手指南
放翁云:纸上得来终觉浅,绝知此事要躬行。
下面我们就落地下,将代码智能模型的训练和推理跑起来~~~
- 第一步:安装transformers框架,因为codebert是基于这个框架的:
pip install transformers --user
- 第二步:安装PyTorch或者Tensorflow作为Transformers的后端,以2021年7月5日这个时间点,需要的PyTorch版本至少是1.5.0以上。驱动能搞定的话,索性就安装最新的吧:
pip install torch torchvision torchtext torchaudio --user
- 第三步,下载微软的数据集
git clone https://github.com/microsoft/CodeXGLUE
- 第四步,我们先玩玩BigCloneBench吧
到Code-Code/Clone-detection-BigCloneBench/code目录下,运行:
python run.py --output_dir=./saved_models --model_type=roberta --config_name=microsoft/codebert-base --model_name_or_path=microsoft/codebert-base --tokenizer_name=roberta-base --do_train --train_data_file=../dataset/train.txt --eval_data_file=../dataset/valid.txt --test_data_file=../dataset/test.txt --epoch 2 --block_size 400 --train_batch_size 16 --eval_batch_size 32 --learning_rate 5e-5 --max_grad_norm 1.0 --evaluate_during_training --seed 123456 2>&1| tee train.log
然后训练就运行起来了:
07/05/2021 16:29:24 - INFO - __main__ - ***** Running training *****
07/05/2021 16:29:24 - INFO - __main__ - Num examples = 90102
07/05/2021 16:29:24 - INFO - __main__ - Num Epochs = 2
07/05/2021 16:29:24 - INFO - __main__ - Instantaneous batch size per GPU = 8
07/05/2021 16:29:24 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 16
07/05/2021 16:29:24 - INFO - __main__ - Gradient Accumulation steps = 1
07/05/2021 16:29:24 - INFO - __main__ - Total optimization steps = 11264
在两张V100卡大约需要训练40分钟左右。
训练之后是验证,会将目前最好的结果保存到checkpoint中以备推理时使用
07/05/2021 17:10:04 - INFO - __main__ - ***** Running evaluation ***** 40950/41541 [00:10<00:00, 2785.61it/s]
07/05/2021 17:10:04 - INFO - __main__ - Num examples = 41541
07/05/2021 17:10:04 - INFO - __main__ - Batch size = 32
07/05/2021 17:16:05 - INFO - __main__ - ***** Eval results *****
07/05/2021 17:16:05 - INFO - __main__ - eval_f1 = 0.9531
07/05/2021 17:16:05 - INFO - __main__ - eval_precision = 0.9579
07/05/2021 17:16:05 - INFO - __main__ - eval_recall = 0.9484
07/05/2021 17:16:05 - INFO - __main__ - eval_threshold = 0.97
07/05/2021 17:16:06 - INFO - __main__ - ********************
07/05/2021 17:16:06 - INFO - __main__ - Best f1:0.9531
07/05/2021 17:16:06 - INFO - __main__ - ********************
07/05/2021 17:16:08 - INFO - __main__ - Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
一次训练两轮,第二轮效果提升到0.97多:
07/05/2021 17:56:43 - INFO - __main__ - ***** Running evaluation ***** 40950/41541 [00:12<00:00, 3535.62it/s]
07/05/2021 17:56:43 - INFO - __main__ - Num examples = 41541
07/05/2021 17:56:43 - INFO - __main__ - Batch size = 32
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
07/05/2021 18:02:44 - INFO - __main__ - ***** Eval results *****
07/05/2021 18:02:44 - INFO - __main__ - eval_f1 = 0.9701
07/05/2021 18:02:44 - INFO - __main__ - eval_precision = 0.9772
07/05/2021 18:02:44 - INFO - __main__ - eval_recall = 0.9633
07/05/2021 18:02:44 - INFO - __main__ - eval_threshold = 0.97
07/05/2021 18:02:45 - INFO - __main__ - ********************
07/05/2021 18:02:45 - INFO - __main__ - Best f1:0.9701
07/05/2021 18:02:45 - INFO - __main__ - ********************
07/05/2021 18:02:47 - INFO - __main__ - Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
然后我们用训好的模型进行推理吧:
python run.py \--output_dir=./saved_models \--model_type=roberta \--config_name=microsoft/codebert-base \--model_name_or_path=microsoft/codebert-base \--tokenizer_name=roberta-base \--do_eval \--do_test \--train_data_file=../dataset/train.txt \--eval_data_file=../dataset/valid.txt \--test_data_file=../dataset/test.txt \--epoch 2 \--block_size 400 \--train_batch_size 16 \--eval_batch_size 32 \--learning_rate 5e-5 \--max_grad_norm 1.0 \--evaluate_during_training \--seed 123456 2>&1| tee test.log
最后我们运行evaluator.py来查看测试结果:
python ../evaluator/evaluator.py -a ../dataset/test.txt -p saved_models/predictions.txt
输出如下:
{'Recall': 0.9677421599288263, 'Prediction': 0.9557057904236594, 'F1': 0.9616080550111168}
准确率0.956, 召回率0.968,还不错~
跟CodeXGLUE的排行榜比一比:
跟榜上的CodeBert的结果基本一致
GraphCodeBert
要提升性能,我们可以用GraphCodeBert来替换CodeBert.
我们先下载GraphCodeBert的代码:
git clone https://github.com/microsoft/CodeBERT
然后转到GraphCodeBERT/clonedetection目录,解压dataset.zip:
unzip dataset.zip
然后就可以像训练codebert一样训练graphcodebert了:
mkdir saved_models
python run.py \--output_dir=saved_models \--config_name=microsoft/graphcodebert-base \--model_name_or_path=microsoft/graphcodebert-base \--tokenizer_name=microsoft/graphcodebert-base \--do_train \--train_data_file=dataset/train.txt \--eval_data_file=dataset/valid.txt \--test_data_file=dataset/test.txt \--epoch 1 \--code_length 512 \--data_flow_length 128 \--train_batch_size 16 \--eval_batch_size 32 \--learning_rate 2e-5 \--max_grad_norm 1.0 \--evaluate_during_training \--seed 123456 2>&1| tee saved_models/train.log
上面的参数是按4个V100 GPU来调的,如果只有两块V100,可以将–code_length改成256.
CodeBert 40分钟左右一轮,GraphCodeBert大约需要6个半小时一轮。
然后我们进行推理:
python run.py --output_dir=saved_models --config_name=microsoft/graphcodebert-base --model_name_or_path=microsoft/graphcodebert-base --tokenizer_name=microsoft/graphcodebert-base --do_eval --do_test --train_data_file=dataset/train.txt --eval_data_file=dataset/valid.txt --test_data_file=dataset/test.txt --epoch 1 --code_length 256 --data_flow_length 128 --train_batch_size 16 --eval_batch_size 32 --learning_rate 2e-5 --max_grad_norm 1.0 --evaluate_during_training --seed 123456 2>&1| tee saved_models/test.log
最后我们解读一下结果吧:
python evaluator/evaluator.py -a dataset/test.txt -p saved_models/predictions.txt 2>&1| tee saved_models/score.log
结果如下:
{'Recall': 0.9589415798936043, 'Prediction': 0.962620653900429, 'F1': 0.9607703728051462}
代码智能:问题与解法相关推荐
- 微软发布代码智能新基准数据集CodeXGLUE,多角度衡量模型优劣
来源 | 微软研究院AI头条 编者按:代码智能(code intelligence)目的是让计算机具备理解和生成代码的能力,并利用编程语言知识和上下文进行推理,支持代码检索.补全.翻译.纠错.问答等场 ...
- 微软亚研院副院长周明:从语言智能到代码智能
11月6日上午,在中国中文信息学会和中国计算机学会联合创办的"语言与智能高峰论坛"上,微软亚洲研究院副院长周明,以<从语言智能到代码智能>为题,介绍了智能代码理解和生成 ...
- 微软亚洲研究院周明:从语言智能到代码智能
来源:NewBeeNLP本文约1600字,建议阅读5分钟本文介绍了微软亚洲研究院自然语言计算组在该研究领域的一系列最新进展. 微软亚洲研究院副院长周明老师报告:From Language Intell ...
- 后盾网lavarel视频项目---Laravel 安装代码智能提示扩展「laravel-ide-helper」
后盾网lavarel视频项目---Laravel 安装代码智能提示扩展「laravel-ide-helper」 一.总结 一句话总结: laravel-ide-helper作用是:代码提示 larav ...
- Atom JS 代码智能提示补全
JS 代码智能提示补全 题外话 官方正式版虽然内置了.autocomplete-plus:最为明显的一个功能就是记忆你已经输入过的名称进行匹配: 但是针对于某些语言来说,还是有些不足的-.其中 JS ...
- AIDE支持实时错误检查、代码重构、代码智能导航、生成APK
AIDE是一个Android Java集成开发环境,可以在Android系统内进行Android软件和游戏的开发.它不仅仅是一个编辑器,而是支持编写-编译-调试运行整个周期,开发人员可以在Androi ...
- ES6 import代码智能转换Babel插件: babel-plugin-imports-transform
babel-plugin-imports-transform ES6 import代码智能转换Babel插件,优化(webpack等)打包构建体积. Github地址: https://github. ...
- JAVA--AI编程助手【代码智能补全工具】盘点,让AI提高你的编程效率
1. 什么是AI编程助手 近几年,随着人工智能的迅速发展,AI在各行各业都有所应用. 特别是近两年,面向开发者的AI开发工具也是层出不穷,如GitHub Copilot.Codota.TabNine. ...
- 华为快应用IDE:代码智能提示及自动补全
代码编辑工具的代码智能提示/自动补全功能几乎是所有代码编写工具具备的基础功能. 华为快应用IDE自然不能少了如此便利的功能,Template模板.Script脚本.Style样式.Pair-Tages ...
- 数睿数据为代表的企业级无代码智能软件产业迎来新风口
文章来源于新华网 无代码平台作为灵活易用的应用构建工具大大提高了软件开发效率,提供了弹性.丰富的应变能力,可应对软件开发速度加快.动态时期变化增多等问题.随着我国数字化进程的推进,无代码智能软件产业迎 ...
最新文章
- 深入讨论.NET Socket的Accept方法
- docker安装mysql并配置,Docker安装MySql-挂载外部数据和配置
- JAVA: final 修饰符
- html中怎样引入外部字体文件路径,CSS引入外部字体
- 循环控制_break语句
- c ++ stl_获取列表的第一个和最后一个元素| C ++ STL
- POJ 2993 Emag eht htiw Em Pleh(模拟)
- 小程序获取城市经纬度_微信小程序获取当前所在城市的方法
- ansible批量安装服务器思路
- 【评分】个人作业——软件工程实践总结作业
- SQlite数据库的C编程接口(一) 简介 ——《Using SQlite》读书笔记
- Thymeleaf 模板布局三种区别
- IM即时通讯项目框架分析
- colorbox ajax,Colorbox弹出层插件
- sjtu1313 太湖旅行
- ggplot2读书笔记9:第六章 标度(二)
- 云笔记的使用感受和选择
- 成都Java开发前景怎么样?好找工作吗?
- php管理vsftp,Vsftp服务器配置文件详解
- 数据科学与计算机学院英文翻译,计算机系毕业论文中英文翻译英文