Source code for bibcat.llm.plots

import pathlib

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

from bibcat import config
from bibcat.llm.io import read_output
from bibcat.llm.metrics import extract_eval_data, extract_roc_data, get_roc_metrics, prepare_roc_inputs
from bibcat.utils.logger_config import setup_logger

logger = setup_logger(__name__, level=config.logging.level)


# create a confusion matrix plot
[docs] def confusion_matrix_plot(summary_output_path: str | pathlib.Path, missions: list[str]) -> None: """Create a confusion matrix figure Create confusion matrix plots (counts and normalized) given a threshold value. Parameters ---------- summary_output_path: str | pathlib.Path the filepath of the evaluation *summary_output.json missions: list[str] list of the mission names to extract the classification labels. Returns ------- """ data = read_output(filename=summary_output_path) # capitalize all mission names just in case when is not missions = [mission.upper() for mission in missions] metrics_data = extract_eval_data(missions=missions, data=data) human = metrics_data["human_labels"] llm = metrics_data["llm_labels"] threshold = metrics_data["threshold"] human_llm_missions = metrics_data["human_llm_missions"] fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True) papertypes = config.llms.papertypes # Absolute label counts ax[0].set_title("Count") ConfusionMatrixDisplay.from_predictions(human, llm, ax=ax[0], cmap=plt.cm.BuPu, colorbar=False, labels=papertypes) # Normalized confusion matrix ax[1].set_title("Normalized") ConfusionMatrixDisplay.from_predictions( human, llm, ax=ax[1], normalize="true", cmap=plt.cm.PuRd, colorbar=False, labels=papertypes ) for axis in ax: axis.set_xlabel("LLM label") axis.set_ylabel("Human label") # Leave more room above the subplots fig.subplots_adjust(top=0.75) # Suptitle fig.suptitle( f"Confusion Matrix at threshold = {threshold} ({config.llms.openai.model}) ", fontsize=14, fontweight="bold" ) if len(missions) == len(config.missions): fig.text( 0.5, 0.9, "All Missions considered", ha="center", fontsize=12, fontstyle="italic", color="gray", ) else: fig.text( 0.5, 0.9, f"Mission(s) considered: {', '.join(missions)}", ha="center", fontsize=12, fontstyle="italic", color="gray", ) fig.text( 0.5, 0.85, f"Mission(s) found: {', '.join(human_llm_missions)}", ha="center", fontsize=10, fontstyle="italic", color="gray", ) # plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Saving the figure cm_plot = ( pathlib.Path(config.paths.output) / f"llms/openai_{config.llms.openai.model}/{config.llms.cm_plot}_t{config.llms.performance.threshold}.png" ) plt.savefig(cm_plot) logger.info(f"The confusion matrix plot is saved on {cm_plot}!")
# create a ROC curve plot
[docs] def roc_plot(summary_output_path: str | pathlib.Path, missions: list[str]) -> None: """Create a Receiver Operating Characteristic (ROC) curve plot Parameters ---------- summary_output_path: str | pathlib.Path the filepath of the evaluation *summary_output.json missions: list[str] list of the mission names to extract the classification labels. Returns ------- """ # read the evaluation summary output file data = read_output(filename=summary_output_path) # capitalize all mission names just in case when is not missions = [mission.upper() for mission in missions] human_labels, llm_confidences, human_llm_missions = extract_roc_data(missions=missions, data=data) binarized_human_labels, llm_confidences, n_papertypes, n_verdicts = prepare_roc_inputs( human_labels, llm_confidences ) # compute ROC curve and ROC AUC (area under curve) for each class if n_papertypes > 2: fpr, tpr, thresholds, roc_auc, macro_roc_auc_ovr, micro_roc_auc_ovr = get_roc_metrics( llm_confidences, binarized_human_labels, n_papertypes ) else: fpr, tpr, thresholds, roc_auc = get_roc_metrics(llm_confidences, binarized_human_labels, n_papertypes) fig, ax = plt.subplots(1, 1, figsize=(8, 8)) bbox_args = dict(boxstyle="round", fc="0.8") if n_papertypes > 2: colors = plt.cm.viridis(np.linspace(0, 1, n_papertypes)) for i in range(n_papertypes): ax.plot( fpr[i], tpr[i], color=colors[i], lw=2, label=f"{config.llms.papertypes[i]} (AUC = {roc_auc[i]:.2f})", ) ax.annotate( f" : macro_roc_auc_ovr = {macro_roc_auc_ovr}\n micro_roc_auc_ovr = {micro_roc_auc_ovr}", xy=(1, 0.35), xycoords="axes fraction", xytext=(-10, -10), textcoords="offset points", ha="right", va="top", bbox=bbox_args, ) else: ax.plot(fpr, tpr, color="b", lw=2, label=f"SCIENCE (AUC={roc_auc:.2f})") # Define the target threshold values to mark target_thresholds = np.arange(0.1, 1.0, 0.1) # For each target threshold, find the index of the closest threshold in the computed array for p in target_thresholds: idx = np.abs(thresholds - p).argmin() ax.scatter(fpr[idx], tpr[idx], marker="o", color="r") ax.annotate( f"{thresholds[idx]:.1f}", (fpr[idx], tpr[idx]), textcoords="offset points", xytext=(5, 5), fontsize=8, color="r", ) # Plot the diagonal line ax.plot([0, 1], [0, 1], "k--", lw=2, label="Random guessing") ax.annotate( f"The number of verdicts : {n_verdicts}", xy=(1, 0.25), xycoords="axes fraction", xytext=(-10, -10), textcoords="offset points", ha="right", va="top", bbox=bbox_args, ) # dummy scatter for the thresholds label ax.scatter([], [], marker="o", color="red", label="Thresholds") ax.set_xlim([0.0, 1.0]) ax.set_ylim([0.0, 1.05]) ax.set_xlabel("False Positive Rate") ax.set_ylabel("True Positive Rate") ax.grid(True) ax.legend(loc="lower right") # Suptitle fig.suptitle("Reciever Operating Characteristic (ROC)", fontsize=14, fontweight="bold") if len(human_llm_missions) > 13: fig.text( 0.5, 0.9, "More than 12 MAST Missions", ha="center", fontsize=10, fontstyle="italic", color="gray", ) else: fig.text( 0.5, 0.9, f"Mission(s): {', '.join(human_llm_missions)}", ha="center", fontsize=10, fontstyle="italic", color="gray", ) plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Saving the figure roc = pathlib.Path(config.paths.output) / f"llms/openai_{config.llms.openai.model}/{config.llms.roc_plot}" plt.savefig(roc) logger.info(f"The roc plot is saved on {roc}!")