docs/E2E_COREFERENCE.md

## Phase 2: End-to-End Span-Based Coreference Resolution

This document describes the end-to-end (E2E) span-based coreference resolution system implemented in Phase 2. This architecture jointly learns mention detection and coreference resolution, achieving higher accuracy than the pipelined Phase 1 approach.

## Overview

### Key Differences from Phase 1

| Aspect | Phase 1 (Pipelined) | Phase 2 (End-to-End) |
|--------|---------------------|----------------------|
| Architecture | Two separate stages | Single joint model |
| Mention Detection | Rule-based, pre-defined | Learned span scoring |
| Optimization | Stages optimized separately | Joint end-to-end optimization |
| Error Propagation | Detection errors → resolution errors | No error propagation |
| Expected F1 | 75-80% | 82-85% |
| Training Time | ~2-3 hours | ~4-6 hours |
| Inference Speed | ~50-100ms/doc | ~80-120ms/doc |

### Architecture Diagram

```
Text → Token Embeddings → BiLSTM Encoder
                             ↓
                    Span Enumeration (all possible spans)
                             ↓
                    Span Scoring (mention detection head)
                             ↓
                    Top-K Pruning (keep best spans)
                             ↓
                    Pairwise Scoring (coreference head)
                             ↓
                    Clustering → Coreference Chains
```

## Quick Start

### Training

```bash
mix nasty.train.e2e_coref \
  --corpus data/ontonotes/train \
  --dev data/ontonotes/dev \
  --output priv/models/en/e2e_coref \
  --epochs 25 \
  --batch-size 16 \
  --learning-rate 0.0005
```

### Evaluation

```bash
mix nasty.eval.e2e_coref \
  --model priv/models/en/e2e_coref \
  --test data/ontonotes/test \
  --baseline
```

The `--baseline` flag compares E2E results with Phase 1 models.

### Usage in Code

```elixir
# Load trained E2E models
{:ok, models, params, vocab} = 
  Nasty.Semantic.Coreference.Neural.E2ETrainer.load_models(
    "priv/models/en/e2e_coref"
  )

# Resolve coreferences in a document
{:ok, resolved_doc} = 
  Nasty.Semantic.Coreference.Neural.E2EResolver.resolve(
    document, 
    models, 
    params, 
    vocab
  )

# Or use auto-loading convenience function
{:ok, resolved_doc} = 
  Nasty.Semantic.Coreference.Neural.E2EResolver.resolve_auto(
    document,
    "priv/models/en/e2e_coref"
  )

# Access coreference chains
chains = resolved_doc.coref_chains
```

## Model Components

### 1. Span Enumeration (`SpanEnumeration`)

Generates all possible spans up to a maximum length, then prunes to top-K candidates.

**Key Functions:**
- `enumerate_spans/2` - Generate all spans up to max_length
- `enumerate_and_prune/2` - Score and prune to top-K
- `span_representation/4` - Compute span embedding

**Span Representation:**
```
span_repr = [start_state, end_state, attention_over_span, width_embedding]
```

**Configuration:**
- `max_length`: 10 tokens (default)
- `top_k`: 50 spans per sentence (default)

### 2. Span Model (`SpanModel`)

Joint architecture with shared encoder and two task-specific heads.

**Components:**
- **Shared Encoder**: BiLSTM (256 hidden units) processes entire document
- **Span Scorer Head**: Feedforward network [256, 128] → Sigmoid (mention detection)
- **Pair Scorer Head**: Feedforward network [512, 256] → Sigmoid (coreference)

**Loss Function:**
```
total_loss = 0.3 * span_loss + 0.7 * coref_loss
```

Both use binary cross-entropy.

### 3. E2E Trainer (`E2ETrainer`)

Training pipeline with joint optimization and early stopping.

**Training Process:**
1. Load training and dev data (OntoNotes format)
2. Build vocabulary from all documents
3. Initialize models with random weights
4. Train for N epochs with Adam optimizer
5. Evaluate on dev set after each epoch
6. Early stopping when dev F1 stops improving

**Hyperparameters:**
- Epochs: 25
- Batch size: 16
- Learning rate: 0.0005 (lower than Phase 1)
- Dropout: 0.3
- Patience: 3 epochs
- Span loss weight: 0.3
- Coref loss weight: 0.7

### 4. E2E Resolver (`E2EResolver`)

Inference using trained models.

**Resolution Steps:**
1. Extract tokens from document AST
2. Convert tokens to IDs using vocabulary
3. Encode with BiLSTM
4. Enumerate and score candidate spans
5. Filter spans by score threshold (default: 0.5)
6. Score all span pairs for coreference
7. Build chains using greedy clustering
8. Attach chains to document

**Clustering Algorithm:**
Greedy left-to-right antecedent selection:
- For each span, find best previous span with score > threshold
- Merge spans into same cluster if score exceeds threshold
- Results in transitively-closed coreference chains

## Data Preparation

### Span Training Data

The E2E model requires two types of training data:

**1. Span Detection Data** (`create_span_training_data/2`):
- Generates (span, label) pairs
- Label = 1 if span is a mention, 0 otherwise
- Enumerates all spans up to max_width
- Samples negative spans at 3:1 ratio

**2. Antecedent Data** (`create_antecedent_data/2`):
- Generates (mention, antecedent, label) triples
- Label = 1 if coreferent, 0 otherwise
- Considers previous N mentions as antecedent candidates
- Samples negative antecedents at 1.5:1 ratio

## Training Options

All training options with defaults:

```bash
--corpus <path>              # Required
--dev <path>                 # Required
--output <path>              # Required
--epochs 25                  # Training epochs
--batch-size 16              # Batch size
--learning-rate 0.0005       # Learning rate
--hidden-dim 256             # LSTM hidden dimension
--dropout 0.3                # Dropout rate
--patience 3                 # Early stopping patience
--max-span-width 10          # Maximum span width
--top-k-spans 50             # Spans to keep per sentence
--span-loss-weight 0.3       # Weight for span detection
--coref-loss-weight 0.7      # Weight for coreference
```

## Evaluation Options

```bash
--model <path>               # Required
--test <path>                # Required
--baseline                   # Compare with Phase 1
--max-span-length 10         # Maximum span length
--top-k-spans 50             # Top K spans to keep
--min-span-score 0.5         # Minimum span score
--min-coref-score 0.5        # Minimum coref score
```

## Performance

### Expected Results

**E2E Model (Phase 2):**
- CoNLL F1: 82-85%
- MUC F1: 83-86%
- B³ F1: 80-83%
- CEAF F1: 82-85%

**Improvement over Phase 1:**
- +5-10 F1 points overall
- Better recall on singletons
- Fewer spurious mention detections
- More accurate pronoun resolution

### Speed Benchmarks

- Training: ~4-6 hours on OntoNotes (single GPU)
- Encoding: ~80 mentions/sec
- Span enumeration: ~500 spans/sec
- Pairwise scoring: ~800 pairs/sec
- End-to-end: ~80-120ms per document

## Advantages of E2E Approach

1. **No Error Propagation**: Mention detection errors don't affect coreference
2. **Joint Optimization**: Both tasks optimized together for best overall performance
3. **Learned Mention Detection**: Model learns what constitutes a mention
4. **Better Boundaries**: Span enumeration finds correct mention boundaries
5. **Global Context**: BiLSTM encoder captures document-level context

## Limitations

1. **Computational Cost**: Enumerating all spans is expensive
2. **Memory Usage**: Requires storing representations for all candidate spans
3. **Max Span Length**: Limited to spans of 10 tokens or less
4. **Pruning Errors**: Top-K pruning may discard valid mentions

## Troubleshooting

### Out of Memory

- Reduce `--batch-size` to 8
- Reduce `--top-k-spans` to 30
- Reduce `--hidden-dim` to 128

### Low Span Detection

- Increase `--span-loss-weight` to 0.5
- Lower `--min-span-score` to 0.3
- Increase `--top-k-spans` to 100

### Low Coreference Accuracy

- Increase `--coref-loss-weight` to 0.8
- Train for more epochs: `--epochs 30`
- Add more training data

### Slow Training

- Increase `--batch-size` to 32 (if memory allows)
- Reduce `--top-k-spans` to 30
- Use GPU acceleration (EXLA)

## Module Reference

### SpanEnumeration

```elixir
enumerate_and_prune(lstm_outputs, opts) :: {:ok, [span()]}
enumerate_spans(lstm_outputs, max_length) :: [{start, end}]
span_representation(lstm_outputs, start, end, width_emb) :: tensor
build_span_scorer(opts) :: Axon.t()
```

### SpanModel

```elixir
build_model(opts) :: models()
build_encoder(vocab_size, embed_dim, hidden_dim, dropout) :: Axon.t()
build_span_scorer(span_dim, hidden_layers, dropout) :: Axon.t()
build_pair_scorer(pair_dim, hidden_layers, dropout) :: Axon.t()
extract_pair_features(span1, span2, tokens) :: tensor
forward(models, params, token_ids, spans) :: {span_scores, coref_scores}
compute_loss(span_scores, coref_scores, gold_span_labels, gold_coref_labels, opts) :: scalar
```

### E2ETrainer

```elixir
train(train_data, dev_data, vocab, opts) :: {:ok, models(), params(), history()}
save_models(models, params, vocab, path) :: :ok
load_models(path) :: {:ok, models(), params(), vocab()}
```

### E2EResolver

```elixir
resolve(document, models, params, vocab, opts) :: {:ok, Document.t()}
resolve_auto(document, model_path, opts) :: {:ok, Document.t()}
```

## References

- Lee et al. (2017). "End-to-end Neural Coreference Resolution"
- Lee et al. (2018). "Higher-order Coreference Resolution with Coarse-to-fine Inference"
- Joshi et al. (2019). "SpanBERT: Improving Pre-training by Representing and Predicting Spans"

## See Also

- [NEURAL_COREFERENCE.md](NEURAL_COREFERENCE.md) - Phase 1 documentation
- [COREFERENCE_TRAINING.md](COREFERENCE_TRAINING.md) - Detailed training guide
- [API.md](API.md) - Full API reference