Grafos Avanzadosunion-findDSUconjuntos-disjuntosgrafos

Union-Find (Disjoint Set Union)

Domina la estructura Union-Find para conjuntos disjuntos y problemas de conectividad

OOI Oaxaca9 de febrero de 20266 min read

¿Qué es Union-Find?

Union-Find (o Disjoint Set Union - DSU) es una estructura de datos que maneja conjuntos disjuntos (que no se sobreponen) con dos operaciones:

  • Find(x): ¿A qué conjunto pertenece x?
  • Union(x, y): Unir los conjuntos de x e y

Complejidad

Con las optimizaciones de compresión de caminos y unión por rango:

OperaciónComplejidad
FindO(α(n))O(\alpha(n))O(1)O(1)
UnionO(α(n))O(\alpha(n))O(1)O(1)

Donde α\alpha es la inversa de Ackermann (crece extremadamente lento, < 5 para cualquier n práctico).

Implementación básica

const int MAXN = 100005;
int padre[MAXN];
int rango[MAXN];

void init(int n) {
    for (int i = 0; i <= n; i++) {
        padre[i] = i;      // Cada elemento es su propio padre
        rango[i] = 0;      // Altura inicial 0
    }
}

// Find con compresión de caminos
int find(int x) {
    if (padre[x] != x) {
        padre[x] = find(padre[x]);  // Compresión
    }
    return padre[x];
}

// Union por rango
void unite(int x, int y) {
    int px = find(x);
    int py = find(y);

    if (px == py) return;  // Ya están en el mismo conjunto

    // Unir el árbol más pequeño al más grande
    if (rango[px] < rango[py]) swap(px, py);
    padre[py] = px;
    if (rango[px] == rango[py]) rango[px]++;
}

// Verificar si están en el mismo conjunto
bool conectados(int x, int y) {
    return find(x) == find(y);
}

Visualización

Inicial: {1} {2} {3} {4} {5}

Union(1, 2):
    1          3   4   5
    |
    2

Union(3, 4):
    1          3      5
    |          |
    2          4

Union(1, 3):
       1           5
      / \
     2   3
         |
         4

Con compresión de caminos, find(4) aplana:
       1           5
     / | \
    2  3  4

Union por tamaño

Alternativa a union por rango:

int padre[MAXN];
int tamano[MAXN];

void init(int n) {
    for (int i = 0; i <= n; i++) {
        padre[i] = i;
        tamano[i] = 1;
    }
}

int find(int x) {
    if (padre[x] != x) {
        padre[x] = find(padre[x]);
    }
    return padre[x];
}

void unite(int x, int y) {
    int px = find(x);
    int py = find(y);

    if (px == py) return;

    // Unir el más pequeño al más grande
    if (tamano[px] < tamano[py]) swap(px, py);
    padre[py] = px;
    tamano[px] += tamano[py];
}

int getTamano(int x) {
    return tamano[find(x)];
}

Aplicaciones

1. Componentes conexas

int contarComponentes(int n) {
    int componentes = 0;
    for (int i = 1; i <= n; i++) {
        if (find(i) == i) {
            componentes++;
        }
    }
    return componentes;
}

2. Detectar ciclos en grafo no dirigido

bool tieneCiclo(vector<pair<int,int>>& aristas) {
    for (auto [u, v] : aristas) {
        if (find(u) == find(v)) {
            return true;  // Ya conectados → agregar arista crea ciclo
        }
        unite(u, v);
    }
    return false;
}

3. Kruskal's MST

struct Arista {
    int u, v, peso;
    bool operator<(const Arista& o) const {
        return peso < o.peso;
    }
};

long long kruskal(int n, vector<Arista>& aristas) {
    sort(aristas.begin(), aristas.end());
    init(n);

    long long costoTotal = 0;
    int aristasUsadas = 0;

    for (auto& e : aristas) {
        if (find(e.u) != find(e.v)) {
            unite(e.u, e.v);
            costoTotal += e.peso;
            aristasUsadas++;

            if (aristasUsadas == n - 1) break;
        }
    }

    return aristasUsadas == n - 1 ? costoTotal : -1;  // -1 si no es conexo
}

4. Unión con información adicional

int padre[MAXN];
int minimo[MAXN];  // Mínimo en cada conjunto
int maximo[MAXN];  // Máximo en cada conjunto
int tamano[MAXN];

void init(int n) {
    for (int i = 0; i <= n; i++) {
        padre[i] = i;
        minimo[i] = i;
        maximo[i] = i;
        tamano[i] = 1;
    }
}

void unite(int x, int y) {
    int px = find(x);
    int py = find(y);

    if (px == py) return;

    if (tamano[px] < tamano[py]) swap(px, py);
    padre[py] = px;
    tamano[px] += tamano[py];
    minimo[px] = min(minimo[px], minimo[py]);
    maximo[px] = max(maximo[px], maximo[py]);
}

Union-Find con rollback

Para deshacer operaciones:

int padre[MAXN];
int rango[MAXN];
stack<pair<int*, int>> historial;

int find(int x) {
    while (padre[x] != x) x = padre[x];
    return x;
}

void unite(int x, int y) {
    int px = find(x);
    int py = find(y);

    if (px == py) return;

    if (rango[px] < rango[py]) swap(px, py);

    historial.push({&padre[py], padre[py]});
    padre[py] = px;

    if (rango[px] == rango[py]) {
        historial.push({&rango[px], rango[px]});
        rango[px]++;
    }
}

void rollback() {
    while (!historial.empty()) {
        auto [ptr, val] = historial.top();
        historial.pop();
        *ptr = val;
    }
}

void saveCheckpoint() {
    historial.push({nullptr, 0});  // Marcador
}

void rollbackToCheckpoint() {
    while (!historial.empty() && historial.top().first != nullptr) {
        auto [ptr, val] = historial.top();
        historial.pop();
        *ptr = val;
    }
    if (!historial.empty()) historial.pop();  // Quitar marcador
}

Weighted Union-Find

Para relaciones con pesos entre elementos:

int padre[MAXN];
long long dist[MAXN];  // Distancia al padre

int find(int x) {
    if (padre[x] == x) return x;

    int root = find(padre[x]);
    dist[x] += dist[padre[x]];  // Actualizar distancia
    padre[x] = root;
    return root;
}

// Agregar relación: dist[y] - dist[x] = d
bool relate(int x, int y, long long d) {
    int px = find(x);
    int py = find(y);

    if (px == py) {
        return dist[x] - dist[y] == d;  // Verificar consistencia
    }

    // dist[px] = dist[x] + nuevo_dist[px]
    // dist[py] = dist[y]
    // Queremos: dist[y] - dist[x] = d
    padre[px] = py;
    dist[px] = dist[y] - dist[x] + d;

    return true;
}

long long getDist(int x, int y) {
    if (find(x) != find(y)) return LLONG_MAX;  // No relacionados
    return dist[x] - dist[y];
}

Template completo

#include <bits/stdc++.h>
using namespace std;

const int MAXN = 200005;
int parent[MAXN];
int rank_[MAXN];
int size_[MAXN];

void init(int n) {
    for (int i = 0; i <= n; i++) {
        parent[i] = i;
        rank_[i] = 0;
        size_[i] = 1;
    }
}

int find(int x) {
    if (parent[x] != x) {
        parent[x] = find(parent[x]);
    }
    return parent[x];
}

bool unite(int x, int y) {
    int px = find(x);
    int py = find(y);

    if (px == py) return false;

    if (rank_[px] < rank_[py]) swap(px, py);
    parent[py] = px;
    size_[px] += size_[py];
    if (rank_[px] == rank_[py]) rank_[px]++;

    return true;
}

bool connected(int x, int y) {
    return find(x) == find(y);
}

int getSize(int x) {
    return size_[find(x)];
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, q;
    cin >> n >> q;

    init(n);

    while (q--) {
        int tipo, a, b;
        cin >> tipo >> a >> b;

        if (tipo == 0) {
            unite(a, b);
        } else {
            cout << (connected(a, b) ? 1 : 0) << "\n";
        }
    }

    return 0;
}

Ejercicios recomendados

  1. CSES - Road Reparation
  2. CSES - Road Construction
  3. LeetCode - Number of Islands II
  4. Codeforces - Roads not only in Berland
  5. AtCoder - Union Find