1. 前言

回想起春招实习面阿里时,面试官看到我简历上写着了解 Netty,就随便问了一个“简单” 的问题:

你知道 Netty 是怎么优化 ThreadLocal 的吗(灵魂拷问)?

尽管之前看过 Netty 的基本使用,也尝试撸过一些小的框架,对 Netty 有最初步的了解,但是从未通说过优化 ThreadLocal。碰巧最近再看 Netty 源码,所以对这个问题做出解答。不过在此之前,先来看下大家应该都使用过的 JDK 内置的 ThreadLocal。

2. ThreadLocal

大家应该多少了解过 JDK 内置的 ThreadLocal,它提供了一种无锁的方式到达线程安全和隔离的效果。在这里我们不多赘述,简单回忆一下 ThreadLocal 的实现原理:

  1. ThreadLocal 就是由每个线程 Thread 类维护一个 ThreadLocal#ThreadLocalMap
  2. 其键为当前 ThreadLocal 实例,值为每个线程隔离的变量值
  3. 这样就能实现线程变量隔离

以下是一个 get 操作的时序图:

ThreadLocal#get 时序图.png

Netty 为其提供了一个更快的替代类 —— FastThreadLocal,那到底快在哪儿呢?

3. 更快的ThreadLocal

回想一下,ThreadLocal#ThreadLocalMap 通过哈希表存储 ThreadLocal 实例与隔离变量值的映射,并用线性探测法解决哈希冲突。既然存在冲突,当然就有性能上的问题,当一个 KEY 发现 HashCode 所对应的槽下标已经存在元素,将会线性查找下一个为为空的槽,严重的话,如果整个数组都冲突则可能导致死循环(当然概率很小)。

而 FastThreadLocal 则摒弃哈希表的数据结构,提供一种更高效的方式来建立这种映射。

InternalThreadLocalMap

前面提到的更高效的方式建立这种映射,答案就在 InternalThreadLocalMap 类内:

  • 它提供一种每个 FastThreadLocal 对应的 index 与当前线程的隔离变量建立映射
  • 但是每个 FastThreadLocal 在初始化时,都要先向 InternalThreadLocalMap 申请一个 index;
  • 这样就能通过 index 获取到当前线程在该 FastThreadLocal 存储的值。

先来看下 InternalThreadLocalMap 继承的 父类 UnpaddedInternalThreadLocalMap 定义,除了定义了需要使用的对象,没有定义任何方法:

class UnpaddedInternalThreadLocalMap {

    // 存放每个线程对应的 InternalThreadLocalMap,如果是 FastThreadLocalThread 则不从这里获取
    static final ThreadLocal<InternalThreadLocalMap> slowThreadLocalMap = new ThreadLocal<InternalThreadLocalMap>();
    
    // 原子类型递增对象,用来生成下一个 FastThreadLocal 对应的 index
    static final AtomicInteger nextIndex = new AtomicInteger();
    
    // 用来建立FastThreadLocal对应的index -> 线程隔离变量的映射
    Object[] indexedVariables;
    
    // ...... 其他成员变量

    UnpaddedInternalThreadLocalMap(Object[] indexedVariables) {
        this.indexedVariables = indexedVariables;
    }
}

而在 InternalThreadLocalMap 则实现了映射关系的建立,以及获取隔离变量的逻辑。

建立映射

建立映射的逻辑很简单,就是根据传入的下标,直接将值存储到数组即可,必要时需要扩容数组的长度:

public boolean setIndexedVariable(int index, Object value) {
    Object[] lookup = indexedVariables;
    if (index < lookup.length) {
        Object oldValue = lookup[index];
        lookup[index] = value;
        return oldValue == UNSET;
    } else {
        // 对 indexedVariables 进行扩容
        expandIndexedVariableTableAndSet(index, value);
        return true;
    }
}

获取隔离变量

获取隔离变量更简单,根据传入 FastThreadLocal 所关联的 index 下标,索引到数组的值返回即可;如果越界,则返回 UNSET

public static final Object UNSET = new Object();

public Object indexedVariable(int index) {
    Object[] lookup = indexedVariables;
    return index < lookup.length? lookup[index] : UNSET;
}

生成index

同时,我们还生成下一个 FastThreadLocal 示例在 indexedVariables 数组内对应的下标 index:

public static int nextVariableIndex() {
    int index = nextIndex.getAndIncrement();
    if (index < 0) {
        nextIndex.decrementAndGet();
        throw new IllegalStateException("too many thread-local indexed variables");
    }
    return index;
}

获取每个线程的InternalThreadLocalMap

那么问题来了,InternalThreadLocalMap 对象是存储在哪里呢? FastThreadLocal 是定位到每个线程维护的 InternalThreadLocalMap 怎么获取的呢?InternalThreadLocalMap 内部提供了一个静态工具类方法,提供返回当前线程对应的 InternalThreadLocalMap 实例的方法:

public static InternalThreadLocalMap get() {
    Thread thread = Thread.currentThread();
    // 根据不同的线程采取不同的获取方式
    if (thread instanceof FastThreadLocalThread) {
        return fastGet((FastThreadLocalThread) thread);
    } else {
        return slowGet();
    }
}

可以看到,根据不同的 Thread 类型,采取不同的获取方式,说明它们的存储方式也不一样:

  • 普通的 Thread 类:存储在 UnpaddedInternalThreadLocalMap 类的 ThreadLocal 当中;
  • Netty 扩展的 FastThreadLocalThread 类:存储在成员变量的 threadLocalMap 属性中。

所以,这就很明白了:

// 针对 FastThreadLocalThread 获取 InternalThreadLocalMap 的方式
// 直接从获取其成员变量 threadLocalMap
private static InternalThreadLocalMap fastGet(FastThreadLocalThread thread) {
    InternalThreadLocalMap threadLocalMap = thread.threadLocalMap();
    // 如果为空,则初始化
    if (threadLocalMap == null) {
        threadLocalMap = new InternalThreadLocalMap(); 
        thread.setThreadLocalMap(threadLocalMap);
    }
    return threadLocalMap;
}

// 针对普通 Thread 获取 InternalThreadLocalMap 的方式
// 从 UnpaddedInternalThreadLocalMap 的 ThreadLocal 中获取
private static InternalThreadLocalMap slowGet() {
    ThreadLocal<InternalThreadLocalMap> slowThreadLocalMap = UnpaddedInternalThreadLocalMap.slowThreadLocalMap;
    InternalThreadLocalMap ret = slowThreadLocalMap.get();
    // 如果为空,则进行初始化
    if (ret == null) {
        ret = new InternalThreadLocalMap();
        slowThreadLocalMap.set(ret);
    }
    return ret;
}

看到这儿,基本就了解了 InternalThreadLocalMap 核心结构,它由每个线程单独存储,并维护一个数组建立 FastThreadLocal 实例与隔离变量值的映射。在 FastThreadLocal 中,就是对 InternalThreadLocalMap 进一步调用而已。

FastThreadLocal

前面提到,每个 FastThreadLocal 都需要对应一个 index,存储在成员变量中:

public class FastThreadLocal<V> {
    
    private final int index;
    
    public FastThreadLocal() {
        // 在初始化的时候向 InternalThreadLocalMap 申请一个 index
        index = InternalThreadLocalMap.nextVariableIndex();
    }
}

获取值时,先获取每个线程存储的 InternalThreadLocalMap,然后取出该 FastThreadLocal 实例在当前线程存储的值:

public final V get() {
    InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get();
    // 根据 index 获取到每个线程单独存储的值
    Object v = threadLocalMap.indexedVariable(index);
    if (v != InternalThreadLocalMap.UNSET) {
        return (V) v;
    }
    return initialize(threadLocalMap);
}

存储值时,同样先获取每个线程存储的 InternalThreadLocalMap,然后将值存储到指定下标即可:

public final void set(V value) {
    if (value != InternalThreadLocalMap.UNSET) {
        InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get();
        setKnownNotUnset(threadLocalMap, value);
    } else {
        remove();
    }
}

private void setKnownNotUnset(InternalThreadLocalMap threadLocalMap, V value) {
    if (threadLocalMap.setIndexedVariable(index, value)) {
        addToVariablesToRemove(threadLocalMap, this);
    }
}

下列是 FastThreadLocal 的 get 方法时序图:

FastThreadLocal#get时序图.png

4. 总结

从上面的源码可以看到,如果我们使用的是 Thread,依然会在 ThreadLocal 寻找存储的 InternalThreadLocalMap ;反之,如果我们只用的 FastThreadLocal,则彻底脱离了 ThreadLocal 的弊端。因此正确使用姿势应该是:

public class FastThreadLocalDemo {

    private static final FastThreadLocal<Integer> threadLocal = new FastThreadLocal<>();

    public static void main(String[] args)  {
        for (int i = 0; i < 3; i++) {
            // 使用 FastThreadLocalThread 启动线程
            new FastThreadLocalThread(() -> {
                try {
                    threadLocal.set(new Random().nextInt());
                    Thread.sleep(new Random().nextInt(3) * 1000);
                    System.out.println(threadLocal.get());
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                threadLocal.remove();
            }).start();
        }
    }
}

之所以称 FastThreadLocalThread 快呢,主要就是因为没有使用 JDK-ThreadLocal 所使用的哈希表结果存储数据,不会发生哈希冲突并通过线性探测法耗时解决冲突

PS. 另外,该问对 FastThreadLocal 还有很多细节没有讲解,因为时间的原因就先说到这儿了,如果有机会再补充。