技术标签: pytorch
pytorch-Field 源码:text/torchtext/data/field.py
【转】如何利用torchtext读取json文件并生成batch 作者:Geek Fly --> 这篇文章挺好的
# coding: utf8
from collections import Counter, OrderedDict
from itertools import chain
import six
import torch
from tqdm import tqdm
from .dataset import Dataset
from .pipeline import Pipeline
from .utils import get_tokenizer, dtype_to_attr, is_tokenizer_serializable
from ..vocab import Vocab, SubwordVocab
class RawField(object):
""" Defines a general datatype.
Every dataset consists of one or more types of data. For instance, a text
classification dataset contains sentences and their classes, while a
machine translation dataset contains paired examples of text in two
languages. Each of these types of data is represented by a RawField object.
A RawField object does not assume any property of the data type and
it holds parameters relating to how a datatype should be processed.
Attributes:
preprocessing: The Pipeline that will be applied to examples
using this field before creating an example.
Default: None.
postprocessing: A Pipeline that will be applied to a list of examples
using this field before assigning to a batch.
Function signature: (batch(list)) -> object
Default: None.
is_target: Whether this field is a target variable.
Affects iteration over batches. Default: False
"""
def __init__(self, preprocessing=None, postprocessing=None, is_target=False):
self.preprocessing = preprocessing
self.postprocessing = postprocessing
self.is_target = is_target
def preprocess(self, x):
""" Preprocess an example if the `preprocessing` Pipeline is provided. """
if self.preprocessing is not None:
return self.preprocessing(x)
else:
return x
def process(self, batch, *args, **kwargs):
""" Process a list of examples to create a batch.
Postprocess the batch with user-provided Pipeline.
Args:
batch (list(object)): A list of object from a batch of examples.
Returns:
object: Processed object given the input and custom
postprocessing Pipeline.
"""
if self.postprocessing is not None:
batch = self.postprocessing(batch)
return batch
class Field(RawField):
"""Defines a datatype together with instructions for converting to Tensor.
Field class models common text processing datatypes that can be represented
by tensors. It holds a Vocab object that defines the set of possible values
for elements of the field and their corresponding numerical representations.
The Field object also holds other parameters relating to how a datatype
should be numericalized, such as a tokenization method and the kind of
Tensor that should be produced.
If a Field is shared between two columns in a dataset (e.g., question and
answer in a QA dataset), then they will have a shared vocabulary.
Attributes:
sequential: Whether the datatype represents sequential data. If False,
no tokenization is applied. Default: True.
use_vocab: Whether to use a Vocab object. If False, the data in this
field should already be numerical. Default: True.
init_token: A token that will be prepended to every example using this
field, or None for no initial token. Default: None.
eos_token: A token that will be appended to every example using this
field, or None for no end-of-sentence token. Default: None.
fix_length: A fixed length that all examples using this field will be
padded to, or None for flexible sequence lengths. Default: None.
dtype: The torch.dtype class that represents a batch of examples
of this kind of data. Default: torch.long.
preprocessing: The Pipeline that will be applied to examples
using this field after tokenizing but before numericalizing. Many
Datasets replace this attribute with a custom preprocessor.
Default: None.
postprocessing: A Pipeline that will be applied to examples using
this field after numericalizing but before the numbers are turned
into a Tensor. The pipeline function takes the batch as a list, and
the field's Vocab.
Default: None.
lower: Whether to lowercase the text in this field. Default: False.
tokenize: The function used to tokenize strings using this field into
sequential examples. If "spacy", the SpaCy tokenizer is
used. If a non-serializable function is passed as an argument,
the field will not be able to be serialized. Default: string.split.
tokenizer_language: The language of the tokenizer to be constructed.
Various languages currently supported only in SpaCy.
include_lengths: Whether to return a tuple of a padded minibatch and
a list containing the lengths of each examples, or just a padded
minibatch. Default: False.
batch_first: Whether to produce tensors with the batch dimension first.
Default: False.
pad_token: The string token used as padding. Default: "<pad>".
unk_token: The string token used to represent OOV words. Default: "<unk>".
pad_first: Do the padding of the sequence at the beginning. Default: False.
truncate_first: Do the truncating of the sequence at the beginning. Default: False
stop_words: Tokens to discard during the preprocessing step. Default: None
is_target: Whether this field is a target variable.
Affects iteration over batches. Default: False
"""
vocab_cls = Vocab
# Dictionary mapping PyTorch tensor dtypes to the appropriate Python
# numeric type.
dtypes = {
torch.float32: float,
torch.float: float,
torch.float64: float,
torch.double: float,
torch.float16: float,
torch.half: float,
torch.uint8: int,
torch.int8: int,
torch.int16: int,
torch.short: int,
torch.int32: int,
torch.int: int,
torch.int64: int,
torch.long: int,
}
ignore = ['dtype', 'tokenize']
def __init__(self, sequential=True, use_vocab=True, init_token=None,
eos_token=None, fix_length=None, dtype=torch.long,
preprocessing=None, postprocessing=None, lower=False,
tokenize=None, tokenizer_language='en', include_lengths=False,
batch_first=False, pad_token="<pad>", unk_token="<unk>",
pad_first=False, truncate_first=False, stop_words=None,
is_target=False):
self.sequential = sequential
self.use_vocab = use_vocab
self.init_token = init_token
self.eos_token = eos_token
self.unk_token = unk_token
self.fix_length = fix_length
self.dtype = dtype
self.preprocessing = preprocessing
self.postprocessing = postprocessing
self.lower = lower
# store params to construct tokenizer for serialization
# in case the tokenizer isn't picklable (e.g. spacy)
self.tokenizer_args = (tokenize, tokenizer_language)
self.tokenize = get_tokenizer(tokenize, tokenizer_language)
self.include_lengths = include_lengths
self.batch_first = batch_first
self.pad_token = pad_token if self.sequential else None
self.pad_first = pad_first
self.truncate_first = truncate_first
try:
self.stop_words = set(stop_words) if stop_words is not None else None
except TypeError:
raise ValueError("Stop words must be convertible to a set")
self.is_target = is_target
def __getstate__(self):
str_type = dtype_to_attr(self.dtype)
if is_tokenizer_serializable(*self.tokenizer_args):
tokenize = self.tokenize
else:
# signal to restore in `__setstate__`
tokenize = None
attrs = {k: v for k, v in self.__dict__.items() if k not in self.ignore}
attrs['dtype'] = str_type
attrs['tokenize'] = tokenize
return attrs
def __setstate__(self, state):
state['dtype'] = getattr(torch, state['dtype'])
if not state['tokenize']:
state['tokenize'] = get_tokenizer(*state['tokenizer_args'])
self.__dict__.update(state)
def __hash__(self):
# we don't expect this to be called often
return 42
def __eq__(self, other):
if not isinstance(other, RawField):
return False
return self.__dict__ == other.__dict__
def preprocess(self, x):
"""Load a single example using this field, tokenizing if necessary.
If the input is a Python 2 `str`, it will be converted to Unicode
first. If `sequential=True`, it will be tokenized. Then the input
will be optionally lowercased and passed to the user-provided
`preprocessing` Pipeline."""
if (six.PY2 and isinstance(x, six.string_types)
and not isinstance(x, six.text_type)):
x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
if self.sequential and isinstance(x, six.text_type):
x = self.tokenize(x.rstrip('\n'))
if self.lower:
x = Pipeline(six.text_type.lower)(x)
if self.sequential and self.use_vocab and self.stop_words is not None:
x = [w for w in x if w not in self.stop_words]
if self.preprocessing is not None:
return self.preprocessing(x)
else:
return x
def process(self, batch, device=None):
""" Process a list of examples to create a torch.Tensor.
Pad, numericalize, and postprocess a batch and create a tensor.
Args:
batch (list(object)): A list of object from a batch of examples.
Returns:
torch.autograd.Variable: Processed object given the input
and custom postprocessing Pipeline.
"""
padded = self.pad(batch)
tensor = self.numericalize(padded, device=device)
return tensor
def pad(self, minibatch):
"""Pad a batch of examples using this field.
Pads to self.fix_length if provided, otherwise pads to the length of
the longest example in the batch. Prepends self.init_token and appends
self.eos_token if those attributes are not None. Returns a tuple of the
padded list and a list containing lengths of each example if
`self.include_lengths` is `True` and `self.sequential` is `True`, else just
returns the padded list. If `self.sequential` is `False`, no padding is applied.
"""
minibatch = list(minibatch)
if not self.sequential:
return minibatch
if self.fix_length is None:
max_len = max(len(x) for x in minibatch)
else:
max_len = self.fix_length + (
self.init_token, self.eos_token).count(None) - 2
padded, lengths = [], []
for x in minibatch:
if self.pad_first:
padded.append(
[self.pad_token] * max(0, max_len - len(x))
+ ([] if self.init_token is None else [self.init_token])
+ list(x[-max_len:] if self.truncate_first else x[:max_len])
+ ([] if self.eos_token is None else [self.eos_token]))
else:
padded.append(
([] if self.init_token is None else [self.init_token])
+ list(x[-max_len:] if self.truncate_first else x[:max_len])
+ ([] if self.eos_token is None else [self.eos_token])
+ [self.pad_token] * max(0, max_len - len(x)))
lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
if self.include_lengths:
return (padded, lengths)
return padded
def build_vocab(self, *args, **kwargs):
"""Construct the Vocab object for this field from one or more datasets.
Arguments:
Positional arguments: Dataset objects or other iterable data
sources from which to construct the Vocab object that
represents the set of possible values for this field. If
a Dataset object is provided, all columns corresponding
to this field are used; individual columns can also be
provided directly.
Remaining keyword arguments: Passed to the constructor of Vocab.
"""
counter = Counter()
sources = []
for arg in args:
if isinstance(arg, Dataset):
sources += [getattr(arg, name) for name, field in
arg.fields.items() if field is self]
else:
sources.append(arg)
for data in sources:
for x in data:
if not self.sequential:
x = [x]
try:
counter.update(x)
except TypeError:
counter.update(chain.from_iterable(x))
specials = list(OrderedDict.fromkeys(
tok for tok in [self.unk_token, self.pad_token, self.init_token,
self.eos_token] + kwargs.pop('specials', [])
if tok is not None))
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
def numericalize(self, arr, device=None):
"""Turn a batch of examples that use this field into a Variable.
If the field has include_lengths=True, a tensor of lengths will be
included in the return value.
Arguments:
arr (List[List[str]], or tuple of (List[List[str]], List[int])):
List of tokenized and padded examples, or tuple of List of
tokenized and padded examples and List of lengths of each
example if self.include_lengths is True.
device (str or torch.device): A string or instance of `torch.device`
specifying which device the Variables are going to be created on.
If left as default, the tensors will be created on cpu. Default: None.
"""
if self.include_lengths and not isinstance(arr, tuple):
raise ValueError("Field has include_lengths set to True, but "
"input data is not a tuple of "
"(data batch, batch lengths).")
if isinstance(arr, tuple):
arr, lengths = arr
lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
if self.use_vocab:
if self.sequential:
arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
else:
arr = [self.vocab.stoi[x] for x in arr]
if self.postprocessing is not None:
arr = self.postprocessing(arr, self.vocab)
else:
if self.dtype not in self.dtypes:
raise ValueError(
"Specified Field dtype {} can not be used with "
"use_vocab=False because we do not know how to numericalize it. "
"Please raise an issue at "
"https://github.com/pytorch/text/issues".format(self.dtype))
numericalization_func = self.dtypes[self.dtype]
# It doesn't make sense to explicitly coerce to a numeric type if
# the data is sequential, since it's unclear how to coerce padding tokens
# to a numeric type.
if not self.sequential:
arr = [numericalization_func(x) if isinstance(x, six.string_types)
else x for x in arr]
if self.postprocessing is not None:
arr = self.postprocessing(arr, None)
var = torch.tensor(arr, dtype=self.dtype, device=device)
if self.sequential and not self.batch_first:
var.t_()
if self.sequential:
var = var.contiguous()
if self.include_lengths:
return var, lengths
return var
文章浏览阅读1.2k次,点赞35次,收藏18次。AutowiredPostConstruct 注释用于在依赖关系注入完成之后需要执行的方法上,以执行任何初始化。此方法必须在将类放入服务之前调用。支持依赖关系注入的所有类都必须支持此注释。即使类没有请求注入任何资源,用 PostConstruct 注释的方法也必须被调用。只有一个方法可以用此注释进行注释。_springboot2.7获取bean
文章浏览阅读2.1k次。理论介绍 节点定义package logistic;public class Instance { public int label; public double[] x; public Instance(){} public Instance(int label,double[] x){ this.label = label; th_logisticregression java
文章浏览阅读981次,点赞21次,收藏18次。本书是获得了很多读者好评的Linux经典畅销书**《Linux从入门到精通》的第2版**。下面我们来进行文件的恢复,执行下文中的lsof命令,在其返回结果中我们可以看到test-recovery.txt (deleted)被删除了,但是其存在一个进程tail使用它,tail进程的进程编号是1535。我们看到文件名为3的文件,就是我们刚刚“误删除”的文件,所以我们使用下面的cp命令把它恢复回去。命令进入该进程的文件目录下,1535是tail进程的进程id,这个文件目录里包含了若干该进程正在打开使用的文件。
文章浏览阅读10w+次,点赞12次,收藏72次。RTMP(Real Time Messaging Protocol)实时消息传输协议是Adobe公司提出得一种媒体流传输协议,其提供了一个双向得通道消息服务,意图在通信端之间传递带有时间信息得视频、音频和数据消息流,其通过对不同类型得消息分配不同得优先级,进而在网传能力限制下确定各种消息得传输次序。_rtmp
文章浏览阅读64次。2017年12月的计算机等级考试将要来临!出国留学网为考生们整理了2017年12月计算机一级MSOffice考试习题,希望能帮到大家,想了解更多计算机等级考试消息,请关注我们,我们会第一时间更新。2017年12月计算机一级MSOffice考试习题(二)一、单选题1). 计算机最主要的工作特点是( )。A.存储程序与自动控制B.高速度与高精度C.可靠性与可用性D.有记忆能力正确答案:A答案解析:计算...
文章浏览阅读356次。在学MYSQL的时候刚刚好看到了这个提权,很久之前用过别人现成的,但是一直时间没去细想, 这次就自己复现学习下。 0x00 UDF 什么是UDF? UDF (user defined function),即用户自定义函数。是通过添加新函数,对MySQL的功能进行扩充,就像使..._the provided input file '/usr/share/metasploit-framework/data/exploits/mysql
文章浏览阅读3.1w次,点赞71次,收藏485次。webService一 WebService概述1.1 WebService是什么WebService是一种跨编程语言和跨操作系统平台的远程调用技术。Web service是一个平台独立的,低耦合的,自包含的、基于可编程的web的应用程序,可使用开放的XML(标准通用标记语言下的一个子集)标准...
文章浏览阅读1w次。前言照例给出官网:Retrofit官网其实大家学习的时候,完全可以按照官网Introduction,自己写一个例子来运行。但是百密一疏,官网可能忘记添加了一句非常重要的话,导致你可能出现如下错误:Could not locate ResponseBody converter错误信息:Caused by: java.lang.IllegalArgumentException: Could not l_已添加addconverterfactory 但是 could not locate responsebody converter
文章浏览阅读1k次。一套键鼠控制Windows+Linux——Synergy在Windows10和Ubuntu18.04共控的实践Synergy简介准备工作(重要)Windows服务端配置Ubuntu客户端配置配置开机启动Synergy简介Synergy能够通过IP地址实现一套键鼠对多系统、多终端进行控制,免去了对不同终端操作时频繁切换键鼠的麻烦,可跨平台使用,拥有Linux、MacOS、Windows多个版本。Synergy应用分服务端和客户端,服务端即主控端,Synergy会共享连接服务端的键鼠给客户端终端使用。本文_linux 18.04 synergy
文章浏览阅读374次。写demo的时候遇到了很多问题,记录一下。安装nacos1.4.0配置mysql数据库,新建nacos_config数据库,并根据初始化脚本新建表,使配置从数据库读取,可单机模式启动也可以集群模式启动,启动时 ./start.sh -m standaloneapplication.properties 主要是db部分配置## Copyright 1999-2018 Alibaba Group Holding Ltd.## Licensed under the Apache License,_seata1.4.0 +nacos 集成
文章浏览阅读833次。iperf使用方法详解 iperf3是一款带宽测试工具,它支持调节各种参数,比如通信协议,数据包个数,发送持续时间,测试完会报告网络带宽,丢包率和其他参数。 安装 sudo apt-get install iperf3 iPerf3常用的参数: -c :指定客户端模式。例如:iperf3 -c 192.168.1.100。这将使用客户端模式连接到IP地址为192.16..._iperf客户端指定ip地址
文章浏览阅读7.4k次。 写这个函数目的不是为了和C/C++库中的函数在性能和安全性上一比高低,只是为了给那些喜欢探讨函数内部实现的网友,提供一种从浮点性到字符串转换的一种途径。 浮点数是有精度限制的,所以即使我们在使用C/C++中的sprintf或者cout 限制,当然这个精度限制是可以修改的。比方在C++中,我们可以cout.precision(10),不过这样设置的整个输出字符长度为10,而不是特定的小数点后1_c++浮点数 转 字符串 精度损失最小