# data_preprocessing
#
# created by LuYF-Lemon-love on October 31, 2022
#
# 该脚本为 TransE 生成数据集
#
# prerequisites:
#     ../origin_data/raw_data.csv
#
# 输出最终的数据
# output:
#     ../final_data/relation2id.txt
#     ../final_data/entity2id.txt
#     ../final_data/train2id.txt
#     ../final_data/valid2id.txt
#     ../final_data/test2id.txt
操作系统:Ubuntu 20.04.5 LTS
生成目录
1
$ mkdir -p ../final_data
导入第三方库
1 2 3
# Third-party and standard-library modules used by the preprocessing steps.
# Each import is on its own line; several `import` statements fused on one
# line are a syntax error.
import numpy as np
import pandas as pd
import random
读取原始数据
1
# Load the raw table that every later step reads from `df`.
raw_data_path = '../origin_data/raw_data.csv'
df = pd.read_csv(raw_data_path)
1 2 3
# Remove the three columns '病理', '诊断' and '预防'; all other columns are
# kept in their original order. A listed column that is absent from the
# table is simply ignored.
excluded_columns = ('病理', '诊断', '预防')
df = df.loc[:, [column for column in df.columns if column not in excluded_columns]]
生成 relation2id.txt
1 2 3 4 5 6 7
# Build the relation -> id mapping and write ../final_data/relation2id.txt.
# Every column except the first is treated as a relation. The file starts
# with the number of relations, followed by one relation name per line; a
# relation's id is its zero-based position in that list.
relation2id = {}
relations = df.columns[1:]
# `with` closes the file even if a write fails; the encoding is explicit
# because the relation names are Chinese text and the platform default
# encoding is not guaranteed to be UTF-8.
with open('../final_data/relation2id.txt', 'w', encoding='utf-8') as f:
    f.write('%d\n' % (len(relations)))
    # `relation_id` instead of `id` so the builtin id() is not shadowed.
    for relation_id, relation in enumerate(relations):
        f.write("%s\n" % relation)
        relation2id[relation] = relation_id
生成 entity2id.txt
1 2 3 4 5 6 7 8 9 10 11 12 13
# Collect every entity name and every [head, tail, relation] triple from `df`.
entitys = set()
triples = []

# Translation table that deletes the four whitespace characters removed
# from tail entity names: space, tab, newline and carriage return.
whitespace_table = str.maketrans('', '', ' \t\n\r')

for _, row in df.iterrows():
    # The '疾病名称' value is the head entity; spaces inside it become hyphens.
    head = row['疾病名称'].replace(' ', '-')
    # Every column except the first is a relation (same convention as
    # relation2id.txt).
    for column_name in df.columns[1:]:
        cell = row[column_name]
        # pd.isna() recognises any missing-value marker, whereas an identity
        # comparison with np.nan only matches that one particular float object.
        if pd.isna(cell):
            continue
        # A cell lists its tail entities separated by ';'.
        for raw_tail in cell.strip(' ;').split(';'):
            tail = raw_tail.translate(whitespace_table)
            # Test for emptiness after cleaning, so a whitespace-only item
            # (e.g. the middle of "a; ;b") does not become an empty entity.
            if not tail:
                continue
            entitys.add(tail)
            triples.append([head, tail, column_name])
    # NOTE(review): the collapsed source does not show how deeply this
    # statement was indented; registering the head once per row is assumed.
    entitys.add(head)
1 2 3 4 5 6 7
# Build the entity -> id mapping and write ../final_data/entity2id.txt.
# The file starts with the number of entities, followed by one entity name
# per line; an entity's id is its zero-based position in that list.
entity2id = {}
# Iterate in sorted order: the iteration order of a set of strings can
# change between interpreter runs (string hash randomization), which would
# give the same entity a different id on every run.
ordered_entities = sorted(entitys)
# `with` closes the file even if a write fails; the encoding is explicit
# because the entity names are Chinese text and the platform default
# encoding is not guaranteed to be UTF-8.
with open('../final_data/entity2id.txt', 'w', encoding='utf-8') as f:
    f.write('%d\n' % (len(ordered_entities)))
    # `entity_id` instead of `id` so the builtin id() is not shadowed.
    for entity_id, entity in enumerate(ordered_entities):
        f.write('%s\n' % entity)
        entity2id[entity] = entity_id
shuffle 数据集
1 2 3
# Shuffle the triples in place. Seeding the global generator first makes
# the resulting order repeatable for a given input order of `triples`.
random.seed(42)
random.shuffle(triples)
# Total number of triples.
total = len(triples)