File size: 21,262 Bytes
33dc3b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
import marimo

__generated_with = "0.14.13"
app = marimo.App(
    width="full",
    app_title="Text classification using Logistic Regression",
)

with app.setup:
    import glob

    import altair as alt
    import eli5
    import marimo as mo
    import numpy as np
    import pandas as pd
    from eli5 import format_as_html
    from sklearn.calibration import calibration_curve
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import (
        brier_score_loss,
        classification_report,
        confusion_matrix,
    )
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import label_binarize


@app.cell
def _():
    mo.md(
        r"""
    # テキスト分類モデルの可視化と解釈

    このノートブックでは、テキスト分類モデルの学習と解釈方法をインタラクティブに探求します。
    [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#logisticregression)のロジスティック回帰モデルを使用し、ELI5ライブラリで予測の説明を可視化します。
    """
    )
    return


@app.cell
def _():
    mo.md(
        r"""
    ## ロジスティック回帰の内部(テキスト分類向け)

    **二値分類(クラスが2つ)**

    $$
    s=\mathbf{w}^\top\mathbf{x}+b
    =\sum_{i=1}^{d} w_i\,x_i + b
    = w_1x_1+w_2x_2+\cdots+w_dx_d+b.
    $$

    **多クラス(クラス $c$ のスコア)**

    $$
    s_c=\mathbf{w}_c^\top\mathbf{x}+b_c
    =\sum_{i=1}^{d} w_{c,i}\,x_i + b_c
    = w_{c,1}x_1+w_{c,2}x_2+\cdots+w_{c,d}x_d+b_c.
    $$

    **全クラス同時表示(行列形)**

    $$
    \mathbf{s}=W\mathbf{x}+\mathbf{b},\quad
    W\in\mathbb{R}^{K\times d},\ \mathbf{s}\in\mathbb{R}^{K}.
    $$

    **テキスト分類の対応づけ(例)**

    $$
    x_i=\mathrm{tfidf}(\text{term}_i,\ \text{doc}),\qquad
    \text{contrib}_{i\to c}=x_i\,w_{c,i}.
    $$

    **二値の確率化(シグモイド)**

    $$
    p(y=1\mid \mathbf{x})=\sigma(s)=\frac{1}{1+e^{-s}},\quad
    \log\frac{p}{1-p}=s
    $$

    **多クラスの確率化(ソフトマックス)**

    $$
    p(y=c\mid \mathbf{x})=\frac{e^{s_c}}{\sum_{k} e^{s_k}},\quad
    s_c=\mathbf{w}_c^\top \mathbf{x}+b_c
    $$

    **決定境界と閾値**
    二値では$s=0$が閾値(通常は$p=0.5$)。
    多クラスでは$\{s_c\}$の最大を選ぶ。

    **学習と正則化**

    $$
    \min_{\{\mathbf{w}_c,b_c\}}
    \left[-\sum_{i}\log p(y_i\mid \mathbf{x}_i)\right]
    +\frac{\lambda}{2}\sum_{c}\|\mathbf{w}_c\|_2^2
    $$

    scikit-learnの$C$は$\lambda$の逆数に相当($C$が小さければ$\Rightarrow$正則化強)。
    """
    )
    return


@app.cell
def _():
    mo.md(
        """
    ## データセットの確認

    4つの作家の小説から構成される小さなテキストコーパスです。
    各作品からは先頭100トークンのチャンク100個があり、トークンはレマ化されています。

    - A. Merritt: The Moon Pool (Science fiction)
    - E. R. Eddison: The Worm Ouroboros (Fantasy)
    - H. G. Wells: The Wonderful Visit (Fantasy)
    - Mark Twain: A Connecticut Yankee in King Arthur's Court (Historical fiction/Fantasy)

    (StandardEbooksより)
    """
    )
    return


@app.cell
def _():
    rng = np.random.default_rng(42)

    df = pd.DataFrame(columns=["text", "label"])
    for csv_file in glob.glob("./*.csv"):
        d = pd.read_csv(csv_file)
        df = pd.concat([df, d]).reset_index(drop=True)

    mo.vstack([mo.md("### コーパス"), df])
    return (df,)


@app.cell
def _():
    mo.md(
        """
    ## ハイパーパラメータの調整

    以下のスライダーとドロップダウンでモデルのハイパーパラメータを調整できます:

    - **テスト比率**: 訓練データとテストデータの分割比率
    - **最大n-gram**: 使用するn-gramの最大長(1=単語、2=連続する2単語、など)
    - **最小出現回数**: 特徴量として扱う単語の最小出現回数
    - **正則化C**: ロジスティック回帰の正則化強度(小さいほど強い正則化)
    """
    )
    return


@app.cell
def _():
    split = mo.ui.slider(0.1, 0.9, value=0.3, step=0.05, label="テスト比率")
    max_ng = mo.ui.slider(1, 4, value=2, step=1, label="最大n-gram")
    min_df = mo.ui.slider(1, 10, value=1, step=1, label="最小出現回数")
    C_pick = mo.ui.dropdown(options=[0.01, 0.1, 1.0, 10.0], value=1.0, label="正則化C")
    k_top = mo.ui.slider(5, 40, value=20, step=5, label="表示上位語数")
    cls_for_cal = mo.ui.dropdown(options=["All"], value="All", label="クラス(信頼性曲線)")
    mo.vstack([split, max_ng, min_df, C_pick, k_top, cls_for_cal])
    return C_pick, cls_for_cal, max_ng, min_df, split


@app.cell
def _(C_pick, df, max_ng, min_df, split):
    for _ in mo.status.progress_bar(
        range(10),
        title="Training Logistic Regression model",
        subtitle="Please wait",
        show_eta=True,
        show_rate=True
    ):
        X_train_text, X_test_text, y_train, y_test = train_test_split(
            df["text"].tolist(),
            df["label"].tolist(),
            test_size=split.value,
            random_state=0,
            stratify=df["label"].tolist(),
        )
        vec = TfidfVectorizer(ngram_range=(1, max_ng.value), min_df=min_df.value)
        X_train = vec.fit_transform(X_train_text)
        X_test = vec.transform(X_test_text)
        clf = LogisticRegression(
            C=float(C_pick.value), max_iter=2000, solver="lbfgs"
        )
        clf.fit(X_train, y_train)
        classes = clf.classes_
    return X_test, X_test_text, X_train, classes, clf, vec, y_test, y_train


@app.cell
def _(X_test, classes, clf, vec, y_test):
    feature_names = vec.get_feature_names_out()
    W = clf.coef_
    b = clf.intercept_
    weights_df = (
        pd.DataFrame(W, index=classes, columns=feature_names)
        .stack()
        .rename("weight")
        .reset_index()
        .rename(columns={"level_0": "class", "level_1": "term"})
    )
    y_pred = clf.predict(X_test)
    y_prob = clf.predict_proba(X_test)
    probs_df = pd.DataFrame(y_prob, columns=classes)
    cm = (
        pd.DataFrame(
            confusion_matrix(y_test, y_pred, labels=classes),
            index=classes,
            columns=classes,
        )
        .reset_index()
        .melt(id_vars="index", var_name="pred", value_name="count")
        .rename(columns={"index": "true"})
    )
    heat = (
        alt.Chart(cm)
        .mark_rect()
        .encode(
            x="pred:N",
            y="true:N",
            color=alt.Color("count:Q", scale=alt.Scale(scheme="blues")),
        )
        .properties(title="混同行列", width=300, height=300)
    )
    text = (
        alt.Chart(cm)
        .mark_text(baseline="middle")
        .encode(
            x="pred:N",
            y="true:N",
            text="count:Q",
            color=alt.condition(
                alt.datum.count > 5,
                alt.value("white"),
                alt.value("black")
            )
        )
    )
    cm_chart = alt.layer(heat, text).resolve_scale(color="independent")
    mo.ui.altair_chart(cm_chart)
    return weights_df, y_pred


@app.cell
def _():
    mo.md(
        """
    ## 重みと文書頻度の関係

    各クラスに対する特徴語(単語/n-gram)の重みを可視化します。
    横軸は文書頻度の対数、縦軸は重みの大きさ(または符号付き)を表します。
    これは、頻度の高い単語ほど重みが大きいわけではないことを示しています。
    """
    )
    return


@app.cell
def _(X_train, classes, vec, weights_df):
    df_counts = pd.Series((X_train>0).sum(axis=0).A1, index=vec.get_feature_names_out(), name="doc_freq").reset_index().rename(columns={"index":"term"})
    wdf = weights_df.merge(df_counts, on="term", how="left")
    wdf["doc_freq"] = wdf["doc_freq"].fillna(0).astype(int)
    wdf["log_df"] = np.log10(wdf["doc_freq"]+1)
    w_cls_pick = mo.ui.dropdown(options=list(classes), value=list(classes)[0], label="クラス")
    abs_toggle = mo.ui.switch(label="絶対値で表示", value=True)
    mo.hstack([w_cls_pick, abs_toggle])
    return abs_toggle, w_cls_pick, wdf


@app.cell
def _(abs_toggle, w_cls_pick, wdf):
    sub = wdf[wdf["class"]==w_cls_pick.value].copy()
    sub["y"] = sub["weight"].abs() if abs_toggle.value else sub["weight"]
    w_chart = alt.Chart(sub).mark_text().encode(
        x=alt.X("log_df:Q", title="log10(文書頻度+1)"),
        y=alt.Y("y:Q", title="重み"),
        text="term:N",
        tooltip=["term", "doc_freq", "weight"]
    ).properties(width=500, height=360, title=f"重み{'(絶対値)' if abs_toggle.value else ''}と出現頻度の関係")

    mo.ui.altair_chart(w_chart)
    return


@app.cell
def _():
    topK = mo.ui.slider(10, 80, value=30, step=5, label="対象語数")
    mo.hstack([topK])
    return (topK,)


@app.cell
def _():
    mo.md(
        """
    ## 共起相関行列

    選択したクラスについて、上位の特徴語同士の共起相関をヒートマップで表示します。
    赤が正の相関(一緒に出現しやすい)、青が負の相関(同時に出現しにくい)です。
    これにより、どの単語がセットでモデルに影響を与えているかを理解できます。
    """
    )
    return


@app.cell
def _(X_train, topK, vec, w_cls_pick, weights_df):
    wabs = (weights_df.assign(absw=weights_df["weight"].abs())
            .sort_values(["class","absw"], ascending=[True, False]))
    terms_ = wabs[wabs["class"]==w_cls_pick.value].head(topK.value)["term"].tolist()
    indices = pd.Series(range(len(vec.get_feature_names_out())), index=vec.get_feature_names_out())
    cols = [indices[t] for t in terms_ if t in indices]
    Xm = (X_train[:, cols] > 0).astype("float32")
    C_ = (Xm.T @ Xm).A
    n = Xm.shape[0]
    p = Xm.mean(axis=0).A1
    std = np.sqrt(p*(1-p)+1e-9)
    corr = (C_/n - np.outer(p,p)) / (np.outer(std,std)+1e-9)
    corr_df = pd.DataFrame(corr, index=terms_, columns=terms_).stack().rename("corr").reset_index().rename(columns={"level_0":"t1","level_1":"t2"})
    heat_ = alt.Chart(corr_df).mark_rect().encode(
        x=alt.X("t1:N", sort=terms_),
        y=alt.Y("t2:N", sort=terms_),
        color=alt.Color("corr:Q", scale=alt.Scale(scheme="redblue", domain=[-1,1])),
        tooltip=["t1","t2","corr"]
    ).properties(width=420, height=420, title=f"共起相関({w_cls_pick.value} 上位語)")
    mo.ui.altair_chart(heat_)
    return


@app.cell
def _():
    # cls_select = mo.ui.dropdown(
    #     options=list(classes), value=list(classes)[0], label="クラス(重み)"
    # )
    k_slider = mo.ui.slider(5, 40, value=5, step=5, label="上位語数")
    # mo.hstack([cls_select, k_slider])
    mo.hstack([k_slider])
    return (k_slider,)


@app.cell
def _():
    # dfc = weights_df[weights_df["class"] == cls_select.value]
    # top_pos = dfc.nlargest(k_slider.value, "weight")
    # top_neg = dfc.nsmallest(k_slider.value, "weight")
    # pos_chart = (
    #     alt.Chart(top_pos.assign(dir="pos"))
    #     .mark_bar()
    #     .encode(x=alt.X("weight:Q", title="重み"), y=alt.Y("term:N", sort="-x"))
    #     .properties(width=420, height=420, title="正の寄与")
    # )
    # neg_chart = (
    #     alt.Chart(top_neg.assign(dir="neg"))
    #     .mark_bar()
    #     .encode(x=alt.X("weight:Q", title="重み"), y=alt.Y("term:N", sort="x"))
    #     .properties(width=420, height=420, title="負の寄与")
    # )
    # mo.hstack([mo.ui.altair_chart(pos_chart), mo.ui.altair_chart(neg_chart)])
    return


@app.cell
def _():
    mo.md(
        """
    ## クラスごとの特徴語

    各クラスで最も影響力の大きい(正と負の)特徴語を表示します。
    棒グラフで重みが大きい単語が分かりやすく可視化され、どの単語がクラスに寄与しているかを確認できます。
    """
    )
    return


@app.cell
def _(classes, clf, k_slider, vec):
    html_global = format_as_html(
        eli5.explain_weights(clf, vec=vec, target_names=list(classes), top=(k_slider.value, k_slider.value))
    )
    mo.Html(html_global)
    return


@app.cell
def _():
    mo.md(
        """
    ## 個別予測の説明

    テストデータから特定の文書を選択して、モデルの予測結果を詳しく調べます。
    [ELI5ライブラリ](https://eli5.readthedocs.io/en/latest/overview.html)を使用して、どの単語がどの程度予測に寄与したかを可視化します。
    """
    )
    return


@app.cell
def _(X_test_text):
    idx = mo.ui.slider(0, len(X_test_text) - 1, value=0, step=1, label="予測対象文書の選択(インデックス)")
    mo.hstack([idx])
    return (idx,)


@app.cell
def _(X_test_text, classes, clf, idx, vec, y_test):
    x = X_test_text[idx.value]
    proba = clf.predict_proba(vec.transform([x]))[0]
    pred = classes[np.argmax(proba)]
    table = pd.DataFrame({"class": classes, "prob": proba}).sort_values(
        "prob", ascending=False
    )
    mo.vstack([mo.md(f"**予測:** {pred} (正解:{y_test[idx.value]})"), mo.ui.table(table)])
    return (x,)


@app.cell
def _(classes, clf, vec, x):
    html_local = format_as_html(
        eli5.explain_prediction(clf, x, vec=vec, target_names=list(classes), top=20)
    )
    mo.Html(html_local)
    return


@app.cell
def _(classes, clf, vec, x):
    html_cf = format_as_html(
        eli5.explain_prediction(
            clf, x, vec=vec, target_names=list(classes), top=(20, 20)
        )
    )
    mo.Html(html_cf)
    return


@app.cell
def _():
    mo.md(
        r"""
    **How to read these explanations**

    - Weights show how the model learned to associate words with a class, but words often occur together. Interpret groups, not single weights.
    - Coefficients depend on feature scaling. Compare contributions in a specific text ($\textrm{value} \times \textrm{weight}$), not raw weights across different feature types.
    - Rare words can have large weights yet seldom matter. Check frequency vs. weight and the per-example highlights above.
    """
    )
    return


@app.cell
def _():
    mo.md(
        """
    ## 反事実的分析(What-if分析)

    選択した文書を編集し、テキストの変更が予測確率にどのように影響するかを観察できます。
    これはモデルの動作をより深く理解し、モデルに対する信頼を築くのに役立ちます。
    """
    )
    return


@app.cell
def _(x):
    editor = mo.ui.text_area(
        label="テキスト編集(反事実)",
        value=x,
        full_width=True,
    )
    editor
    return (editor,)


@app.cell
def _(classes, clf, editor, vec, x):
    x2 = editor.value
    p1 = clf.predict_proba(vec.transform([x]))[0]
    p2 = clf.predict_proba(vec.transform([x2]))[0]
    delta = pd.DataFrame(
        {"class": classes, "before": p1, "after": p2, "diff": p2 - p1}
    ).sort_values("after", ascending=False)
    mo.ui.table(delta)
    return


@app.cell
def _():
    mo.md(
        """
    ## 正則化パスの可視化

    正則化係数Cを変化させたときの特徴語の重みの変化をプロットします。
    これにより、どの特徴がロバストで、どの特徴が過学習の可能性があるかを理解できます。
    """
    )
    return


@app.cell
def _(X_train, vec, y_train):
    Cgrid = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0]
    rows = []
    terms = []
    if hasattr(vec, "get_feature_names_out"):
        terms = list(vec.get_feature_names_out())
    terms_pick = terms[:1000]
    pick = terms_pick[:5]
    for C in Cgrid:
        m = LogisticRegression(
            C=C, max_iter=2000, solver="lbfgs"
        ).fit(X_train, y_train)
        W2 = m.coef_
        fn = vec.get_feature_names_out()
        dfW = (
            pd.DataFrame(W2, index=m.classes_, columns=fn)
            .stack()
            .rename("w")
            .reset_index()
            .rename(columns={"level_0": "class", "level_1": "term"})
        )
        rows.append(dfW[dfW["term"].isin(pick)].assign(C=C))
    path_df = (
        pd.concat(rows) if rows else pd.DataFrame(columns=["class", "term", "w", "C"])
    )
    chart = (
        alt.Chart(path_df)
        .mark_line()
        .encode(
            x=alt.X("C:Q", scale=alt.Scale(type="log")),
            y="w:Q",
            color="term:N",
            facet="class:N",
        )
        .properties(width=260, height=160, title="正則化で重みがどう変わるか")
    )
    mo.ui.altair_chart(chart)
    return


@app.cell
def _():
    mo.md(
        """
    ## 信頼性曲線

    モデルの予測確率がどれだけ信頼できるかを評価するための信頼性曲線を表示します。
    理想的には45度線に近い方が良く、予測確率が実際の正解率と一致していることを意味します。
    """
    )
    return


@app.cell
def _(X_test, classes, clf, y_test):
    probs = clf.predict_proba(X_test)
    dfp = pd.DataFrame(probs, columns=classes)
    df_true = pd.Series(y_test, name="true")
    cls_pick = classes[0]
    bins = []
    lines = []
    for cls in classes:
        y_bin = (df_true == cls).astype(int).to_numpy()
        prob_cls = dfp[cls].to_numpy()
        f_obs, f_pred = calibration_curve(y_bin, prob_cls, n_bins=8, strategy="uniform")
        bins.append(
            pd.DataFrame({"class": cls, "mean_pred": f_pred, "empirical": f_obs})
        )
    cal_df = pd.concat(bins)
    base = (
        alt.Chart(cal_df)
        .mark_line()
        .encode(
            x=alt.X("mean_pred:Q", title="平均予測確率 (below x=y is over-confident)"),
            y=alt.Y("empirical:Q", title="実測精度 (above x=y is under-confident)"),
            color="class:N",
        )
        .properties(width=420, height=300, title="信頼性曲線")
        .interactive()
    )
    diag = (
        alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
        .mark_line(strokeDash=[5, 5], color="black")
        .encode(x="x:Q", y="y:Q")
        .interactive()
    )
    combined = base + diag
    mo.ui.altair_chart(combined)
    return


@app.cell
def _(X_test, classes, clf, cls_for_cal, y_test):
    cal_sel_m = cls_for_cal.value if 'cls_for_cal' in globals() else "All"
    cal_proba_matrix_m = clf.predict_proba(X_test)
    cal_df_proba_m = pd.DataFrame(cal_proba_matrix_m, columns=classes)
    cal_true_series_m = pd.Series(y_test, name="true")
    if cal_sel_m == "All":
        cal_class_list_m = list(classes)
    else:
        cal_class_list_m = [cal_sel_m]
    cal_rows_m = []
    for cal_c_m in cal_class_list_m:
        cal_ybin_m = (cal_true_series_m == cal_c_m).astype(int).to_numpy()
        cal_p_m = cal_df_proba_m[cal_c_m].to_numpy()
        cal_fobs_m, cal_fpred_m = calibration_curve(cal_ybin_m, cal_p_m, n_bins=10, strategy="uniform")
        cal_n_m = len(cal_p_m)
        cal_bin_m = pd.cut(cal_p_m, bins=np.linspace(0,1,11), right=False, include_lowest=True)
        cal_group_m = pd.DataFrame({"bin": cal_bin_m, "p": cal_p_m, "y": cal_ybin_m}).groupby("bin")
        cal_ece_m = (cal_group_m.apply(lambda g: abs(g["y"].mean() - g["p"].mean()) * len(g) / cal_n_m)).sum()
        cal_rows_m.append({"class": cal_c_m, "ECE": cal_ece_m})
    cal_ece_df_m = pd.DataFrame(cal_rows_m).sort_values("ECE").reset_index(drop=True)
    cal_class_show_m = cal_class_list_m[0]
    cal_hist_chart_m = alt.Chart(pd.DataFrame({"p": cal_df_proba_m[cal_class_show_m]})).mark_bar().encode(
        x=alt.X("p:Q", bin=alt.Bin(maxbins=20), title=f"予測確率 p({cal_class_show_m})"),
        y=alt.Y("count()", title="件数")
    ).properties(width=360, height=220, title="信頼度ヒストグラム")
    mo.hstack([mo.ui.table(cal_ece_df_m), mo.ui.altair_chart(cal_hist_chart_m)])
    return


@app.cell
def _():
    mo.md(
        """
    ## 混同行列の詳細分析

    特定の真のクラスと予測クラスの組み合わせについて、実際の予測例を確認できます。
    これにより、モデルがどのような誤分類をしているかを具体的に観察できます。
    """
    )
    return


@app.cell
def _(X_test_text, classes, y_pred, y_test):
    df_show = pd.DataFrame({"text": X_test_text, "true": y_test, "pred": y_pred})
    sel_true = mo.ui.dropdown(
        options=list(classes), value=list(classes)[0], label="True"
    )
    sel_pred = mo.ui.dropdown(
        options=list(classes), value=list(classes)[0], label="Pred"
    )
    mo.hstack([sel_true, sel_pred])
    return df_show, sel_pred, sel_true


@app.cell
def _(df_show, sel_pred, sel_true):
    subset = df_show[
        (df_show["true"] == sel_true.value) & (df_show["pred"] == sel_pred.value)
    ].reset_index(drop=True)
    mo.ui.table(subset)
    return


@app.cell
def _():
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()