# Source code for cerebras.modelzoo.data_preparation.nlp.pubmed.TextFormatting

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Script to format PubMed full-text articles (commercial / open-access subsets) and the abstracts contained in PubMed baseline and daily-update files

Reference: https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT

"""

import csv
import glob
import os

import pubmed_parser as pmp


class TextFormatting:
    """Convert PubMed XML/NXML dumps into size-limited plain-text shards.

    Every article is written as one line of text followed by a blank line.
    Shards are named ``<stem>_<N>.txt``, where ``<stem>`` is
    ``output_filename`` without its extension. A companion
    ``<stem>_0_stats.csv`` records how many articles each shard holds, plus
    a final totals row.
    """

    def __init__(
        self,
        pubmed_path,
        output_filename,
        filesize_limit=5 * (10**9),
        recursive=False,
    ):
        """
        :param str pubmed_path: Path to folder containing PubMed files
        :param str output_filename: Path of the text file to be written;
            shard names are derived from it by appending ``_<N>`` to its stem
        :param Optional[int] filesize_limit: Size threshold (as reported by
            ``file.tell()``); once a shard exceeds it a new shard is started,
            so a shard may overshoot by up to one article
        :param Optional[bool] recursive: If true, search for xml/nxml files
            recursively within subfolders
        """
        self.pubmed_path = pubmed_path
        print(f"self.pubmed_path:{pubmed_path}")
        self.recursive = recursive
        self.filesize = int(filesize_limit)
        self.output_folder = os.path.dirname(output_filename)
        # dirname() is "" for a bare filename; os.makedirs("") would raise.
        if self.output_folder:
            os.makedirs(self.output_folder, exist_ok=True)
        self.filename = output_filename

    def _shard_filename(self, file_num):
        """Return the path of the ``file_num``-th output text shard."""
        # splitext (not str.split('.')) so dots in directory names such as
        # "./data" or "out.v1/" do not truncate the path.
        stem, _ = os.path.splitext(self.filename)
        return f"{stem}_{int(file_num)}.txt"

    @staticmethod
    def _open_text(path):
        """Open ``path`` for writing UTF-8 text with LF line endings."""
        # Explicit UTF-8: PubMed text is full of non-ASCII characters, and
        # the locale default encoding may not be able to represent them.
        return open(path, mode='w', newline='\n', encoding='utf-8')

    def _write_shards(self, records, write_article, total_label):
        """Write ``records`` into size-limited text shards plus a stats CSV.

        :param records: Iterable of per-article records.
        :param write_article: Callable ``(ofile, record)`` that writes the
            body of one article (without the trailing blank line).
        :param str total_label: ``fname`` value of the final totals row.
        :returns: Number of articles written successfully.
        """
        file_num = 0
        num_articles = 0
        total_articles = 0
        output_filename = self._shard_filename(file_num)
        csv_file = os.path.splitext(output_filename)[0] + "_stats.csv"
        # The csv module requires newline='' to control line endings itself.
        with open(csv_file, 'w', newline='', encoding='utf-8') as csv_fh:
            csv_writer = csv.DictWriter(
                csv_fh, fieldnames=['fname', 'num_articles']
            )
            csv_writer.writeheader()
            ofile = self._open_text(output_filename)
            try:
                for record in records:
                    try:
                        write_article(ofile, record)
                    except Exception as err:
                        # Best effort: terminate the partial article and move
                        # on, but do not hide the failure completely.
                        print(f"Skipping article that failed to format: {err!r}")
                        ofile.write("\n\n")
                        continue
                    ofile.write("\n\n")
                    num_articles += 1

                    if ofile.tell() > self.filesize:
                        ofile.close()
                        csv_writer.writerow(
                            {
                                'fname': output_filename,
                                'num_articles': num_articles,
                            }
                        )
                        total_articles += num_articles
                        num_articles = 0
                        file_num += 1
                        output_filename = self._shard_filename(file_num)
                        print(f" -- Creating new file: {output_filename}")
                        ofile = self._open_text(output_filename)
            finally:
                # Closing an already-closed file is a no-op, so this is safe
                # even if opening the next shard failed.
                ofile.close()

            total_articles += num_articles
            csv_writer.writerow(
                {'fname': output_filename, 'num_articles': num_articles}
            )
            csv_writer.writerow(
                {'fname': total_label, 'num_articles': total_articles}
            )
        return total_articles

    @staticmethod
    def _write_abstract(ofile, abstract):
        """Write the lines of ``abstract`` to ``ofile``, joined by spaces."""
        for line in abstract.splitlines():
            if len(line) < 30:
                # Refer to https://pubmed.ncbi.nlm.nih.gov/4969/
                # Multiple paragraphs in abstract with subtitles such as
                # "Result". Removing these subtitles ONLY.
                continue
            ofile.write(line.strip() + " ")

    def _iter_abstracts(self):
        """Yield the non-empty abstract strings of every matching .xml file."""
        if self.recursive:
            # glob only descends into subfolders through a "**" component;
            # with recursive=True "**" also matches zero directories.
            pattern = os.path.join(self.pubmed_path, '**', '*.xml')
        else:
            pattern = os.path.join(self.pubmed_path, '*.xml')
        for filename in glob.iglob(pattern, recursive=self.recursive):
            print(f"Processing: {filename}")
            for dict_out in pmp.parse_medline_xml(filename):
                abstract = dict_out.get('abstract')
                # Some articles have no abstract:
                # https://pubmed.ncbi.nlm.nih.gov/13787/
                if abstract:
                    yield abstract

    def merge_abstracts(self):
        """Merge abstracts from PubMed baseline/update .xml files into shards."""
        total_articles = self._write_shards(
            self._iter_abstracts(), self._write_abstract, 'Total abstracts'
        )
        print(f"**** Total number of abstracts = {total_articles}")

    @staticmethod
    def _write_fulltext(ofile, record):
        """Write title, abstract and body of one article, joined by spaces.

        ``record`` is a ``(header_dict, body_list)`` pair as produced by
        ``pmp.parse_pubmed_xml`` / ``pmp.parse_pubmed_paragraph``.
        """
        header_dict, body_list = record
        if header_dict:
            ofile.write(header_dict['full_title'].strip() + ". ")
            if header_dict.get('abstract', None):
                TextFormatting._write_abstract(ofile, header_dict['abstract'])
        for dict_entry in body_list or ():
            section = dict_entry['section']
            # Section titles of 30 chars or fewer are skipped (presumably
            # plain headings such as "Methods").
            if len(section) > 30:
                ofile.write(section.strip() + ". ")
            for line in dict_entry['text'].splitlines():
                ofile.write(line.strip() + " ")

    def _iter_fulltext(self, folders, ex_fh):
        """Yield ``(header_dict, body_list)`` for each .nxml file in ``folders``.

        Files from which neither a header nor a body could be parsed are
        listed, one per line, in ``ex_fh`` instead of being yielded.
        """
        for folder in folders:
            # With recursive=False, glob treats "**" like "*", so only files
            # exactly one subfolder below ``folder`` are matched.
            pattern = folder + '/**/*.nxml'
            for filename in glob.iglob(pattern, recursive=self.recursive):
                print(f"Processing: {filename}")
                header_dict = pmp.parse_pubmed_xml(filename)
                body_list = pmp.parse_pubmed_paragraph(
                    filename, all_paragraph=True
                )
                if not header_dict and not body_list:
                    ex_fh.write(filename)
                    ex_fh.write('\n')
                    continue
                yield header_dict, body_list

    def merge_fulltext(self):
        """Merge full-text .nxml articles into shards, one article per line.

        Unparseable files are listed in ``exceptions.txt`` in the output folder.
        """
        top_level_folders = [
            os.path.join(self.pubmed_path, x)
            for x in os.listdir(self.pubmed_path)
        ]
        top_level_folders = [x for x in top_level_folders if os.path.isdir(x)]
        print(top_level_folders)

        not_written = os.path.join(self.output_folder, "exceptions.txt")
        with self._open_text(not_written) as ex_fh:
            total_articles = self._write_shards(
                self._iter_fulltext(top_level_folders, ex_fh),
                self._write_fulltext,
                'Total num articles',
            )
        print(f"**** Total number of full text articles = {total_articles}")

    def merge(self, dataset_name):
        """Dispatch to the merge routine matching ``dataset_name``.

        :raises ValueError: if ``dataset_name`` is not a supported dataset.
        """
        if dataset_name in ("pubmed_baseline", "pubmed_daily_update"):
            self.merge_abstracts()
        elif dataset_name in ("pubmed_fulltext", "pubmed_open_access"):
            self.merge_fulltext()
        else:
            raise ValueError(f"Incorrect dataset_name: {dataset_name} passed")