Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pandas as pd | |
| from tqdm import tqdm | |
| from facility_predict import Preprocess, Facility_Model, obj_Facility_Model, processor | |
def predict_batch_from_csv(input_file, output_file, text_column="facility_name"):
    """Predict a label for every row of a CSV file and save the result.

    The CSV at ``input_file`` is loaded with pandas. Each value of
    ``text_column`` is cleaned with ``processor.clean_text``, tokenized with
    ``processor.process_tokenizer`` and passed to
    ``obj_Facility_Model.inference``. Rows whose text is null, or whose
    cleaned text is the empty string, get an empty-string prediction and
    never reach the tokenizer or the model.

    Args:
        input_file: Path (or file-like object) readable by ``pd.read_csv``.
        output_file: Destination for the output CSV, which contains every
            input column plus a ``prediction`` column.
        text_column: Name of the column holding the text to classify.

    Returns:
        A human-readable status message naming the output file.

    Raises:
        KeyError: If ``text_column`` is not a column of the input CSV.
    """
    # Load batch data from CSV
    batch_data = pd.read_csv(input_file)

    # Fail before any model work when the expected column is absent, and
    # tell the user which columns were actually found.
    if text_column not in batch_data.columns:
        raise KeyError(
            f"Column {text_column!r} not found in input CSV; "
            f"available columns: {list(batch_data.columns)}"
        )

    predictions = []
    # Only one column is read, so iterate it directly instead of building a
    # Series per row with iterrows(); tqdm still tracks progress.
    for text in tqdm(batch_data[text_column], total=len(batch_data)):
        cleaned_text = "" if pd.isnull(text) else processor.clean_text(text)
        if cleaned_text == "":
            # Nothing to classify: keep the row but leave its prediction blank.
            predictions.append("")
            continue
        prepared_data = processor.process_tokenizer(cleaned_text)
        predictions.append(obj_Facility_Model.inference(prepared_data))

    # Put the predictions next to the original columns, aligned by position.
    output_data = pd.DataFrame({"prediction": predictions})
    pred_output_df = pd.concat(
        [batch_data.reset_index(drop=True), output_data], axis=1
    )

    # Save predictions to CSV
    pred_output_df.to_csv(output_file, index=False)
    return f"Prediction completed. Results saved to {output_file}"
# Define the Gradio interface components.
# The legacy `gr.inputs` / `gr.outputs` namespaces are deprecated and are not
# available in newer Gradio releases, so the unified `gr.File` component is
# used for both the upload and the download widget.
# NOTE(review): `type="file"` is intentionally omitted because it is not
# accepted by newer Gradio releases -- confirm against the pinned version.
input_csv = gr.File(label="Input CSV")
output_csv = gr.File(label="Output CSV")
# Define the prediction function for the Gradio interface
def predict_interface(input_file):
    """Run batch prediction on an uploaded CSV and return the result path.

    Args:
        input_file: The uploaded file as handed over by Gradio. Either an
            object exposing the temporary file's path as ``.name`` or a
            plain path string is accepted.

    Returns:
        The path of the CSV written by ``predict_batch_from_csv``
        (``"./output.csv"``).

    Raises:
        ValueError: If no file was uploaded (``input_file`` is ``None``).
    """
    if input_file is None:
        # Submitting without an upload would otherwise fail with an opaque
        # AttributeError on `None.name`.
        raise ValueError("No input CSV was provided.")

    # Use `.name` when the upload is a file wrapper; otherwise treat the
    # value itself as the path.
    input_path = getattr(input_file, "name", input_file)

    output_file = "./output.csv"
    predict_batch_from_csv(input_path, output_file)
    return output_file
# Wire the upload widget, the prediction callback and the download widget
# together into a single Gradio app.
iface = gr.Interface(
    title="CSV Batch Prediction",
    fn=predict_interface,
    inputs=input_csv,
    outputs=output_csv,
)

# Start serving the app.
iface.launch()